{ "cells": [ { "cell_type": "markdown", "id": "f56cc5ad", "metadata": {}, "source": [ "# NDIS Project - PBSP Scoring - Page 3" ] }, { "cell_type": "code", "execution_count": null, "id": "a8d844ea", "metadata": { "hide_input": false }, "outputs": [], "source": [ "import os\n", "from ipywidgets import interact\n", "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as mtick\n", "from qdrant_client import QdrantClient\n", "from qdrant_client.http.models import Distance, VectorParams, Batch, Filter, FieldCondition, Range, MatchValue\n", "import json\n", "import spacy\n", "from spacy import displacy\n", "import nltk\n", "from nltk import sent_tokenize\n", "from sklearn.feature_extraction import text\n", "from pprint import pprint\n", "import re\n", "from flair.embeddings import TransformerDocumentEmbeddings\n", "from flair.data import Sentence\n", "from flair.models import TARSClassifier\n", "from sentence_transformers import SentenceTransformer, util\n", "import pandas as pd\n", "import argilla as rg\n", "from argilla.metrics.text_classification import f1\n", "from typing import Dict\n", "from setfit import SetFitModel\n", "from tqdm import tqdm\n", "import time\n", "for i in tqdm(range(15), disable=True):\n", " time.sleep(1)" ] }, { "cell_type": "code", "execution_count": null, "id": "96b83a1d", "metadata": {}, "outputs": [], "source": [ "#initializations\n", "embedding = TransformerDocumentEmbeddings('distilbert-base-uncased')\n", "client = QdrantClient(\n", " host=os.environ[\"QDRANT_API_URL\"], \n", " api_key=os.environ[\"QDRANT_API_KEY\"],\n", " timeout=60,\n", " port=443\n", ")\n", "collection_name = \"my_collection\"\n", "model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')\n", "vector_dim = 384 #{distilbert-base-uncased: 768, multi-qa-MiniLM-L6-cos-v1:384}\n", "sf_func_model_name = \"setfit-zero-shot-classification-pbsp-p3-func\"\n", "sf_func_model = SetFitModel.from_pretrained(f\"aammari/{sf_func_model_name}\")\n", "tars_model_path = 'few-shot-model-gain-avoid'\n", "tars = TARSClassifier().load(tars_model_path+'/best-model.pt')\n", "\n", "# download nltk 'punkt' if not available\n", "try:\n", " nltk.data.find('tokenizers/punkt')\n", "except LookupError:\n", " nltk.download('punkt')\n", "\n", "# download nltk 'averaged_perceptron_tagger' if not available\n", "try:\n", " nltk.data.find('taggers/averaged_perceptron_tagger')\n", "except LookupError:\n", " nltk.download('averaged_perceptron_tagger')\n", " \n", "#argilla\n", "rg.init(\n", " api_url=os.environ[\"ARGILLA_API_URL\"],\n", " api_key=os.environ[\"ARGILLA_API_KEY\"]\n", ")" ] }, { "cell_type": "markdown", "id": "84add56f", "metadata": { "hide_input": true }, "source": [ "### Domain Expert Section\n", "#### Enter the Topic Glossary" ] }, { "cell_type": "code", "execution_count": null, "id": "17fe501c", "metadata": { "hide_input": false }, "outputs": [], "source": [ "bhvr_onto_lst = [\n", " 'hit employees',\n", " 'push people',\n", " 'throw objects',\n", " 'beat students' \n", "]\n", "bhvr_onto_text_input = widgets.Textarea(\n", " value='\\n'.join(bhvr_onto_lst),\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '100%', 'width': '90%'}\n", ")\n", "bhvr_onto_label = widgets.Label(value='Behaviours')\n", "bhvr_onto_box = widgets.VBox([bhvr_onto_label, bhvr_onto_text_input], \n", " layout={'width': '400px', 'height': '150px'})" ] }, { "cell_type": 
"code", "execution_count": null, "id": "7fa6ce86", "metadata": {}, "outputs": [], "source": [ "fh_onto_lst = [\n", " 'Gain the teacher attention',\n", " 'Complete work in class',\n", " 'Avoid difficult work'\n", "]\n", "\n", "fh_onto_text_input = widgets.Textarea(\n", " value='\\n'.join(fh_onto_lst),\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '100%', 'width': '90%'}\n", ")\n", "fh_onto_label = widgets.Label(value='Functional Hypothesis')\n", "fh_onto_box = widgets.VBox([fh_onto_label, fh_onto_text_input], \n", " layout={'width': '400px', 'height': '150px'})" ] }, { "cell_type": "code", "execution_count": null, "id": "20a1c75c", "metadata": { "scrolled": true }, "outputs": [], "source": [ "rep_onto_lst = [\n", " 'Ask teacher for help',\n", " 'Replace full body slam',\n", " 'Use a next sign'\n", "]\n", "\n", "rep_onto_text_input = widgets.Textarea(\n", " value='\\n'.join(rep_onto_lst),\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '100%', 'width': '90%'}\n", ")\n", "rep_onto_label = widgets.Label(value='Replacement Behaviour')\n", "rep_onto_box = widgets.VBox([rep_onto_label, rep_onto_text_input], \n", " layout={'width': '400px', 'height': '150px'})\n", "\n", "#onto_boxes = widgets.HBox([bhvr_onto_box, fh_onto_box, rep_onto_box], \n", "# layout={'width': '90%', 'height': '150px'})\n", "\n", "onto_boxes = widgets.HBox([fh_onto_box], \n", " layout={'width': '90%', 'height': '150px'})\n", "\n", "display(onto_boxes)" ] }, { "cell_type": "code", "execution_count": null, "id": "72c2c6f9", "metadata": { "hide_input": false }, "outputs": [], "source": [ "#Text Preprocessing\n", "try:\n", " nlp = spacy.load('en_core_web_sm')\n", "except OSError:\n", " spacy.cli.download('en_core_web_sm')\n", " nlp = spacy.load('en_core_web_sm')\n", "sw_lst = text.ENGLISH_STOP_WORDS\n", "def preprocess(onto_lst):\n", " cleaned_onto_lst = []\n", " pattern = re.compile(r'^[a-z ]*$')\n", " for document in onto_lst:\n", " text = []\n", " doc = nlp(document)\n", " person_tokens = []\n", " for w in doc:\n", " if w.ent_type_ == 'PERSON':\n", " person_tokens.append(w.lemma_)\n", " for w in doc:\n", " if not w.is_stop and not w.is_punct and not w.like_num and not len(w.text.strip()) == 0 and not w.lemma_ in person_tokens:\n", " text.append(w.lemma_.lower())\n", " texts = [t for t in text if len(t) > 1 and pattern.search(t) is not None and t not in sw_lst]\n", " cleaned_onto_lst.append(\" \".join(texts))\n", " return cleaned_onto_lst\n", "\n", "cl_bhvr_onto_lst = preprocess(bhvr_onto_lst)\n", "cl_fh_onto_lst = preprocess(fh_onto_lst)\n", "cl_rep_onto_lst = preprocess(rep_onto_lst)\n", "\n", "#pprint(cl_bhvr_onto_lst)\n", "#pprint(cl_fh_onto_lst)\n", "#pprint(cl_rep_onto_lst)" ] }, { "cell_type": "code", "execution_count": null, "id": "a1f934eb", "metadata": {}, "outputs": [], "source": [ "#compute document embeddings\n", "\n", "# distilbert-base-uncased from Flair\n", "def embeddings(cl_onto_lst):\n", " emb_onto_lst = []\n", " for doc in cl_onto_lst:\n", " sentence = Sentence(doc)\n", " embedding.embed(sentence)\n", " emb_onto_lst.append(sentence.embedding.tolist())\n", " return emb_onto_lst\n", "\n", "# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n", "def sentence_embeddings(cl_onto_lst):\n", " emb_onto_lst_temp = model.encode(cl_onto_lst)\n", " emb_onto_lst = [x.tolist() for x in emb_onto_lst_temp]\n", " return emb_onto_lst\n", "\n", "'''\n", "emb_bhvr_onto_lst = embeddings(cl_bhvr_onto_lst)\n", 
"emb_fh_onto_lst = embeddings(cl_fh_onto_lst)\n", "emb_rep_onto_lst = embeddings(cl_rep_onto_lst)\n", "'''\n", "\n", "emb_bhvr_onto_lst = sentence_embeddings(cl_bhvr_onto_lst)\n", "emb_fh_onto_lst = sentence_embeddings(cl_fh_onto_lst)\n", "emb_rep_onto_lst = sentence_embeddings(cl_rep_onto_lst)" ] }, { "cell_type": "code", "execution_count": null, "id": "6302e312", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#add to qdrant collection\n", "def add_to_collection():\n", " global cl_bhvr_onto_lst, emb_bhvr_onto_lst, cl_fh_onto_lst, emb_fh_onto_lst, cl_rep_onto_lst, emb_rep_onto_lst\n", " client.recreate_collection(\n", " collection_name=collection_name,\n", " vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),\n", " )\n", " doc_count = len(emb_bhvr_onto_lst) + len(emb_fh_onto_lst) + len(emb_rep_onto_lst)\n", " ids = list(range(1, doc_count+1))\n", " payloads = [{\"ontology\": \"behaviours\", \"phrase\": x} for x in cl_bhvr_onto_lst] + \\\n", " [{\"ontology\": \"functional_hypothesis\", \"phrase\": y} for y in cl_fh_onto_lst] + \\\n", " [{\"ontology\": \"replacement_behaviour\", \"phrase\": z} for z in cl_rep_onto_lst]\n", " vectors = emb_bhvr_onto_lst+emb_fh_onto_lst+emb_rep_onto_lst\n", " client.upsert(\n", " collection_name=f\"{collection_name}\",\n", " points=Batch(\n", " ids=ids,\n", " payloads=payloads,\n", " vectors=vectors\n", " ),\n", " )\n", "\n", "def count_collection():\n", " return len(client.scroll(\n", " collection_name=f\"{collection_name}\"\n", " )[0])\n", "\n", "add_to_collection()\n", "point_count = count_collection()\n", "#print(point_count)" ] }, { "cell_type": "code", "execution_count": null, "id": "b74861d4", "metadata": {}, "outputs": [], "source": [ "query_filter=Filter(\n", " must=[ \n", " FieldCondition(\n", " key='ontology',\n", " match=MatchValue(value=\"functional_hypothesis\")# Condition based on values of `rand_number` field.\n", " )\n", " ]\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "40070313", "metadata": {}, "outputs": [], "source": [ "#verb phrase extraction\n", "def extract_vbs(data_chunked):\n", " for tup in data_chunked:\n", " if len(tup) > 2:\n", " yield(str(\" \".join(str(x[0]) for x in tup)))\n", "\n", "def get_verb_phrases(nltk_query):\n", " data_tok = nltk.word_tokenize(nltk_query) #tokenisation\n", " data_pos = nltk.pos_tag(data_tok) #POS tagging\n", " cfgs = [\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: 
{<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\"\n", " ]\n", " vbs = []\n", " for cfg_1 in cfgs: \n", " chunker = nltk.RegexpParser(cfg_1)\n", " data_chunked = chunker.parse(data_pos)\n", " vbs += extract_vbs(data_chunked)\n", " return vbs" ] }, { "cell_type": "code", "execution_count": null, "id": "1550437b", "metadata": {}, "outputs": [], "source": [ "#query and get score\n", "\n", "# distilbert-base-uncased from Flair\n", "def get_query_vector(query):\n", " sentence = Sentence(query)\n", " embedding.embed(sentence)\n", " query_vector = sentence.embedding.tolist()\n", " return query_vector\n", "\n", "# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n", "def sentence_get_query_vector(query):\n", " query_vector = model.encode(query)\n", " return query_vector\n", "\n", "def search_collection(ontology, query_vector):\n", " query_filter=Filter(\n", " must=[ \n", " FieldCondition(\n", " key='ontology',\n", " match=MatchValue(value=ontology)\n", " )\n", " ]\n", " )\n", " \n", " hits = client.search(\n", " collection_name=f\"{collection_name}\",\n", " query_vector=query_vector,\n", " query_filter=query_filter, \n", " append_payload=True, \n", " limit=point_count \n", " )\n", " return hits\n", "\n", "semantic_passing_score = 0.50\n", "\n", "\n", "#ontology = 'behaviours'\n", "#query = 'punch father face'\n", "#query_vector = sentence_get_query_vector(query)\n", "#hist = search_collection(ontology, query_vector)" ] }, { "cell_type": "code", "execution_count": null, "id": "02fda761", "metadata": {}, "outputs": [], "source": [ "# format output\n", "def color(df):\n", " return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#ADD8E6')\n", "\n", "def annotate_query(highlights, query):\n", " ents = []\n", " for h in highlights:\n", " ent_dict = {}\n", " for match in re.finditer(h, query):\n", " ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": 'GLOSSARY'}\n", " break\n", " if len(ent_dict.keys()) > 0:\n", " ents.append(ent_dict)\n", " return ents" ] }, { "cell_type": "code", "execution_count": null, "id": "79b519a6", "metadata": {}, "outputs": [], "source": [ "#setfit sentence extraction\n", "def extract_sentences(nltk_query):\n", " sentences = sent_tokenize(nltk_query)\n", " return sentences" ] }, { "cell_type": "code", "execution_count": null, "id": "38594d91", "metadata": {}, "outputs": [], "source": [ "def convert_df(result_df):\n", " new_df = pd.DataFrame(columns=['text', 'prediction'])\n", " new_df['text'] = result_df['Phrase']\n", " new_df['prediction'] = result_df.apply(lambda row: [[row['Topic'], min(row['Score'], 1.0)]], axis=1)\n", " return new_df" ] }, { "cell_type": "code", "execution_count": null, "id": "1e575a22", "metadata": {}, "outputs": [], "source": [ "def custom_f1(data: Dict[str, float], title: str):\n", " from plotly.subplots import make_subplots\n", " import plotly.colors\n", " import random\n", "\n", " fig = make_subplots(\n", " rows=2,\n", " cols=1,\n", " subplot_titles=[ \"Overall Model Score\", \"Model Score By Category\", ],\n", " )\n", "\n", " x = ['precision', 'recall', 'f1']\n", " macro_data = [v for k, v in data.items() if \"macro\" in k]\n", " fig.add_bar(\n", " x=x,\n", " y=macro_data,\n", " row=1,\n", " col=1,\n", " )\n", " per_label = {\n", " k: v\n", " for k, v in 
data.items()\n", " if all(key not in k for key in [\"macro\", \"micro\", \"support\"])\n", " }\n", "\n", " num_labels = int(len(per_label.keys())/3)\n", " fixed_colors = [str(color) for color in plotly.colors.qualitative.Plotly]\n", " colors = random.sample(fixed_colors, num_labels)\n", "\n", " fig.add_bar(\n", " x=[k for k, v in per_label.items()],\n", " y=[v for k, v in per_label.items()],\n", " row=2,\n", " col=1,\n", " marker_color=[colors[int(i/3)] for i in range(0, len(per_label.keys()))]\n", " )\n", " fig.update_layout(showlegend=False, title_text=title)\n", "\n", " return fig" ] }, { "cell_type": "code", "execution_count": null, "id": "07e3ff7b", "metadata": {}, "outputs": [], "source": [ "def get_null_class_df(sentences, result_df):\n", " sents = result_df['Phrase'].tolist()\n", " null_sents = [x for x in sentences if x not in sents]\n", " topics = ['NO FUNCTION'] * len(null_sents)\n", " scores = [0.90] * len(null_sents)\n", " null_df = pd.DataFrame({'Phrase': null_sents, 'Topic': topics, 'Score': scores})\n", " return null_df" ] }, { "cell_type": "code", "execution_count": null, "id": "0d5f8d3c", "metadata": {}, "outputs": [], "source": [ "#setfit func query and get predicted topic\n", "\n", "def get_sf_func_topic(sentences):\n", " preds = list(sf_func_model(sentences))\n", " return preds\n", "def get_sf_func_topic_scores(sentences):\n", " preds = sf_func_model.predict_proba(sentences)\n", " preds = [max(list(x)) for x in preds]\n", " return preds" ] }, { "cell_type": "code", "execution_count": null, "id": "67bf2154", "metadata": {}, "outputs": [], "source": [ "# setfit func format output\n", "ind_func_topic_dict = {\n", " 0: 'NO FUNCTION',\n", " 1: 'FUNCTION',\n", " }\n", "\n", "highlight_threshold = 0.25\n", "passing_score = 0.50\n", "\n", "def sf_func_color(df):\n", " return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#CCFFCC')\n", "\n", "def sf_annotate_query(highlights, query, topics):\n", " ents = []\n", " query = query.strip() # remove newline characters from the query string\n", " for h, t in zip(highlights, topics):\n", " h = re.escape(h) # escape special characters in the highlights string\n", " ent_dict = {}\n", " for match in re.finditer(h, query):\n", " ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": t}\n", " break\n", " if len(ent_dict.keys()) > 0:\n", " ents.append(ent_dict)\n", " return ents" ] }, { "cell_type": "code", "execution_count": null, "id": "316bd9e2", "metadata": {}, "outputs": [], "source": [ "#query and get predicted topic\n", "\n", "p_classes = {'gain_attention': 0,\n", " 'avoid_attention': 1,\n", " 'unknown': 2\n", " }\n", "def get_topic(sentences):\n", " preds = []\n", " for t in sentences:\n", " sentence = Sentence(t)\n", " tars.predict(sentence)\n", " try:\n", " pred = p_classes[sentence.tag]\n", " except:\n", " pred = 2\n", " preds.append(pred)\n", " return preds\n", "def get_topic_scores(sentences):\n", " preds = []\n", " for t in sentences:\n", " sentence = Sentence(t)\n", " tars.predict(sentence)\n", " try:\n", " pred = sentence.score\n", " except:\n", " pred = 0.75\n", " preds.append(pred)\n", " return preds" ] }, { "cell_type": "code", "execution_count": null, "id": "1cadde44", "metadata": {}, "outputs": [], "source": [ "# format output\n", "ind_topic_dict = {\n", " 0: 'GAIN-ATTENTION',\n", " 1: 'AVOID-ATTENTION',\n", " 2: 'UNKNOWN'\n", " }\n", "\n", "topic_color_dict = {\n", " 'GAIN-ATTENTION': '#FFCCCC',\n", " 'AVOID-ATTENTION': '#CCFFFF'\n", " }\n", "\n", "gain_avoid_passing_score = 
0.25\n", "\n", "def gain_avoid_color(df, color):\n", " return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)\n", "\n", "def gain_avoid_annotate_query(highlights, query, topics):\n", " ents = []\n", " for h, t in zip(highlights, topics):\n", " ent_dict = {}\n", " for match in re.finditer(h, query):\n", " ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": t}\n", " break\n", " if len(ent_dict.keys()) > 0:\n", " ents.append(ent_dict)\n", " return ents" ] }, { "cell_type": "code", "execution_count": null, "id": "51dbd744", "metadata": {}, "outputs": [], "source": [ "def path_to_image_html(path):\n", " return ''\n", "\n", "final_passing = 0.0\n", "def display_final_df(agg_df):\n", " tags = []\n", " crits = [\n", " 'GAIN-ATTENTION',\n", " 'AVOID-ATTENTION'\n", " ]\n", " orig_crits = crits\n", " crits = [x for x in crits if x in agg_df.index.tolist()]\n", " bools = [agg_df.loc[crit, 'Final_Score'] > final_passing for crit in crits]\n", " paths = ['./tick_green.png' if x else './cross_red.png' for x in bools]\n", " df = pd.DataFrame({'Topic': crits, 'USED': paths})\n", " rem_crits = [x for x in orig_crits if x not in crits]\n", " if len(rem_crits) > 0:\n", " df2 = pd.DataFrame({'Topic': rem_crits, 'USED': ['./cross_red.png'] * len(rem_crits)})\n", " df = pd.concat([df, df2])\n", " df = df.set_index('Topic')\n", " pd.set_option('display.max_colwidth', None)\n", " display(HTML('
 " display(HTML('<div>' + df.to_html(classes=[\"align-center\"], index=True, escape=False, formatters=dict(USED=path_to_image_html)) + '</div>
'))" ] }, { "cell_type": "markdown", "id": "2c6e9fe7", "metadata": {}, "source": [ "### Practitioner Section\n", "#### Enter a summary statement outlining the functional hypothesis" ] }, { "cell_type": "code", "execution_count": null, "id": "76dd8cab", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#demo with Voila\n", "\n", "func_label = widgets.Label(value='Please type your answer:')\n", "func_text_input = widgets.Textarea(\n", " value='',\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '300px', 'width': '90%'}\n", ")\n", "\n", "func_nlp_btn = widgets.Button(\n", " description='Score Functions',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Score Functions',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "gain_avoid_nlp_btn = widgets.Button(\n", " description='Detect Gain / Avoid',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Detect Gain / Avoid',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_agr_btn = widgets.Button(\n", " description='Validate Data',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Validate Data',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_eval_btn = widgets.Button(\n", " description='Evaluate Model',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Evaluate Model',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "btn_box = widgets.HBox([bhvr_agr_btn, bhvr_eval_btn], \n", " layout={'width': '100%', 'height': '160%'})\n", "func_btn_box = widgets.HBox([func_nlp_btn, gain_avoid_nlp_btn], \n", " layout={'width': '100%', 'height': '160%'})\n", "func_outt = widgets.Output()\n", "func_outt.layout.height = '100%'\n", "func_outt.layout.width = '100%'\n", "func_box = widgets.VBox([func_text_input, func_btn_box, btn_box, func_outt], \n", " layout={'width': '100%', 'height': '160%'})\n", "dataset_rg_name = 'pbsp-page3-func-argilla-ds'\n", "agrilla_df = None\n", "annotated = False\n", "sub_2_result_dfs = []\n", "def on_func_button_next(b):\n", " global fh_onto_lst, cl_fh_onto_lst, emb_fh_onto_lst, agrilla_df\n", " with func_outt:\n", " clear_output()\n", " fh_onto_lst = fh_onto_text_input.value.split(\"\\n\")\n", " cl_fh_onto_lst = preprocess(fh_onto_lst)\n", " orig_cl_dict = {x:y for x,y in zip(cl_fh_onto_lst, fh_onto_lst)}\n", " emb_fh_onto_lst = sentence_embeddings(cl_fh_onto_lst)\n", " add_to_collection()\n", " query = func_text_input.value\n", " vbs = get_verb_phrases(query)\n", " cl_vbs = preprocess(vbs)\n", " emb_vbs = sentence_embeddings(cl_vbs)\n", " vb_ind = -1\n", " highlights = []\n", " highlight_scores = []\n", " result_dfs = []\n", " for query_vector in emb_vbs:\n", " vb_ind += 1\n", " hist = search_collection('functional_hypothesis', query_vector)\n", " hist_dict = [dict(x) for x in hist]\n", " scores = [x['score'] for x in hist_dict]\n", " payloads = [orig_cl_dict[x['payload']['phrase']] for x in hist_dict]\n", " result_df = pd.DataFrame({'Score': scores, 'Glossary': payloads})\n", " result_df = result_df[result_df['Score'] >= semantic_passing_score]\n", " if len(result_df) > 0:\n", " highlights.append(vbs[vb_ind])\n", " highlight_scores.append(result_df.Score.max())\n", " 
result_df['Phrase'] = [vbs[vb_ind]] * len(result_df)\n", " result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " result_dfs.append(result_df)\n", " else:\n", " continue\n", " ents = []\n", " colors = {}\n", " if len(highlights) > 0:\n", " ents = annotate_query(highlights, query)\n", " for ent in ents:\n", " colors[ent['label']] = '#ADD8E6'\n", " \n", " #setfit function\n", " sentences = extract_sentences(query)\n", " cl_sentences = preprocess(sentences)\n", " topic_inds = get_sf_func_topic(cl_sentences)\n", " topics = [ind_func_topic_dict[i] for i in topic_inds]\n", " scores = get_sf_func_topic_scores(cl_sentences)\n", " sf_func_result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", " sf_func_sub_result_df = sf_func_result_df[sf_func_result_df['Topic'] == 'FUNCTION']\n", " sub_2_result_df = sf_func_result_df[sf_func_result_df['Topic'] == 'NO FUNCTION']\n", " sub_2_result_df = pd.concat([sub_2_result_df, sf_func_sub_result_df]).reset_index(drop=True)\n", " sub_2_result_dfs.append(sub_2_result_df)\n", " sf_func_highlights = []\n", " sf_func_ents = []\n", " if len(sf_func_sub_result_df) > 0:\n", " sf_func_highlights = sf_func_sub_result_df['Phrase'].tolist()\n", " sf_func_highlight_topics = sf_func_sub_result_df['Topic'].tolist()\n", " sf_func_highlight_scores = sf_func_sub_result_df['Score'].tolist() \n", " sf_func_ents = sf_annotate_query(sf_func_highlights, query, sf_func_highlight_topics)\n", " for ent, hs in zip(sf_func_ents, sf_func_highlight_scores):\n", " if hs >= passing_score:\n", " colors[ent['label']] = '#CCFFCC'\n", " else:\n", " colors[ent['label']] = '#FFCC66'\n", " options = {\"ents\": list(colors), \"colors\": colors}\n", " if len(sf_func_ents) > 0:\n", " ents = ents + sf_func_ents\n", " \n", " ex = [{\"text\": query,\n", " \"ents\": ents,\n", " \"title\": None}]\n", " if len(ents) > 0:\n", " title = \"Answer Highlights\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " html = displacy.render(ex, style=\"ent\", manual=True, options=options)\n", " display(HTML(html))\n", " if len(result_dfs) > 0:\n", " title = \"Similar to Glossary\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " result_df = pd.concat(result_dfs).reset_index(drop = True)\n", " result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " sub_2_result_df = result_df.copy()\n", " sub_2_result_df['Topic'] = ['FUNCTION'] * len(result_df)\n", " sub_2_result_df = sub_2_result_df[['Phrase', 'Topic', 'Score']].drop_duplicates().reset_index(drop=True)\n", " null_df = get_null_class_df(vbs, sub_2_result_df)\n", " if len(null_df) > 0:\n", " sub_2_result_df = pd.concat([sub_2_result_df, null_df]).reset_index(drop=True)\n", " sub_2_result_dfs.append(sub_2_result_df)\n", " agg_df = result_df.groupby(result_df.Phrase).max()\n", " agg_df['Phrase'] = agg_df.index\n", " agg_df = agg_df.reset_index(drop=True)\n", " agg_df = agg_df.drop(columns=['Glossary'])\n", " result_df = pd.merge(result_df, agg_df, 'inner', ['Phrase', 'Score'])\n", " result_df = result_df[['Phrase', 'Glossary', 'Score']]\n", " result_df = result_df.set_index('Phrase')\n", " display(color(result_df))\n", " if len(sf_func_sub_result_df) > 0:\n", " title = \"Detected Functions\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " result_df = sf_func_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " result_df = result_df.set_index('Phrase')\n", " display(sf_func_color(result_df))\n", " if len(sub_2_result_dfs) > 0:\n", " sub_2_result_df = pd.concat(sub_2_result_dfs).reset_index(drop=True)\n", " agrilla_df = sub_2_result_df.copy()\n", "\n", "def on_gain_avoid_button_next(b):\n", " global agrilla_df\n", " with func_outt:\n", " clear_output()\n", " query = func_text_input.value\n", " sentences = extract_sentences(query)\n", " cl_sentences = preprocess(sentences)\n", " topic_inds = get_topic(cl_sentences)\n", " topics = [ind_topic_dict[i] for i in topic_inds]\n", " scores = get_topic_scores(cl_sentences)\n", " result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", " sub_result_df = result_df[(result_df['Score'] >= gain_avoid_passing_score) & (result_df['Topic'] != 'UNKNOWN')]\n", " sub_2_result_df = result_df[result_df['Topic'] == 'UNKNOWN']\n", " highlights = []\n", " if len(sub_result_df) > 0:\n", " highlights = sub_result_df['Phrase'].tolist()\n", " highlight_topics = sub_result_df['Topic'].tolist() \n", " ents = gain_avoid_annotate_query(highlights, query, highlight_topics)\n", " colors = {}\n", " for ent, ht in zip(ents, highlight_topics):\n", " colors[ent['label']] = topic_color_dict[ht]\n", "\n", " ex = [{\"text\": query,\n", " \"ents\": ents,\n", " \"title\": None}]\n", " title = \"Gaining & Avoidance Highlights\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " html = displacy.render(ex, style=\"ent\", manual=True, jupyter=True, options={'colors': colors})\n", " display(HTML(html))\n", " title = \"Used Approach Classifications\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " for top in topic_color_dict.keys():\n", " top_result_df = sub_result_df[sub_result_df['Topic'] == top]\n", " if len(top_result_df) > 0:\n", " top_result_df = top_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " top_result_df = top_result_df.set_index('Phrase')\n", " top_result_df = top_result_df[['Score']]\n", " display(HTML(\n", " f'
<h3>{top}</h3>
'))\n", " display(gain_avoid_color(top_result_df, topic_color_dict[top]))\n", " \n", " agg_df = sub_result_df.groupby('Topic')['Score'].sum()\n", " agg_df = agg_df.to_frame()\n", " agg_df.index.name = 'Topic'\n", " agg_df.columns = ['Total Score']\n", " agg_df = agg_df.assign(\n", " Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00\n", " )\n", " agg_df = agg_df.sort_values(by='Final_Score', ascending=False)\n", " title = \"Gaining & Avoidance Coverage\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " agg_df['Topic'] = agg_df.index\n", " rem_topics= [x for x in list(topic_color_dict.keys()) if not x in agg_df.Topic.tolist()]\n", " if len(rem_topics) > 0:\n", " rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})\n", " agg_df = pd.concat([agg_df, rem_agg_df])\n", " labels = agg_df['Final_Score'].round(1).astype('str') + '%'\n", " ax = agg_df.plot.bar(x='Topic', y='Final_Score', rot=0, figsize=(20, 5), align='center')\n", " for container in ax.containers:\n", " ax.bar_label(container, labels=labels)\n", " ax.yaxis.set_major_formatter(mtick.PercentFormatter())\n", " ax.legend([\"Final Score (%)\"])\n", " ax.set_xlabel('')\n", " plt.show()\n", " title = \"Gaining & Avoidance Scores\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " display_final_df(agg_df)\n", " if len(sub_2_result_df) > 0:\n", " sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)\n", " agrilla_df = sub_result_df.copy()\n", " else:\n", " print(query)\n", "\n", "def on_agr_button_next(b):\n", " global agrilla_df, annotated\n", " with func_outt:\n", " clear_output()\n", " if agrilla_df is not None:\n", " # convert the dataframe to the structure accepted by argilla\n", " converted_df = convert_df(agrilla_df)\n", " # convert pandas dataframe to DatasetForTextClassification\n", " dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n", " # delete the old DatasetForTextClassification from the Argilla web app if exists\n", " rg.delete(dataset_rg_name)\n", " # load the new DatasetForTextClassification into the Argilla web app\n", " rg.log(dataset_rg, name=dataset_rg_name)\n", " annotated = True\n", " else:\n", " display(Markdown(\"
**Please score the answer first!**
\"))\n", " \n", "def on_eval_button_next(b):\n", " global annotated\n", " with func_outt:\n", " clear_output()\n", " if annotated:\n", " data = dict(f1(dataset_rg_name))['data']\n", " display(custom_f1(data, \"Model Evaluation Results\"))\n", " else:\n", " display(Markdown(\"
**Please score the answer and validate the data first!**
\"))\n", "\n", "func_nlp_btn.on_click(on_func_button_next)\n", "gain_avoid_nlp_btn.on_click(on_gain_avoid_button_next)\n", "bhvr_agr_btn.on_click(on_agr_button_next)\n", "bhvr_eval_btn.on_click(on_eval_button_next)\n", "\n", "display(func_label, func_box)" ] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3.9 (Argilla)", "language": "python", "name": "argilla" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "258.097px" }, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }