{ "cells": [ { "cell_type": "markdown", "id": "f56cc5ad", "metadata": {}, "source": [ "# NDIS Project - PBSP Scoring - Page 3" ] }, { "cell_type": "code", "execution_count": null, "id": "a8d844ea", "metadata": { "hide_input": false }, "outputs": [], "source": [ "import os\n", "from ipywidgets import interact\n", "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n", "from qdrant_client import QdrantClient\n", "from qdrant_client.http.models import Distance, VectorParams, Batch, Filter, FieldCondition, Range, MatchValue\n", "import json\n", "import spacy\n", "from spacy import displacy\n", "import nltk\n", "from nltk import sent_tokenize\n", "from sklearn.feature_extraction import text\n", "from pprint import pprint\n", "import re\n", "from flair.embeddings import TransformerDocumentEmbeddings\n", "from flair.data import Sentence\n", "from sentence_transformers import SentenceTransformer, util\n", "import pandas as pd\n", "import argilla as rg\n", "from argilla.metrics.text_classification import f1\n", "from typing import Dict\n", "from setfit import SetFitModel\n", "from tqdm import tqdm\n", "import time\n", "for i in tqdm(range(30), disable=True):\n", " time.sleep(1)" ] }, { "cell_type": "code", "execution_count": null, "id": "96b83a1d", "metadata": {}, "outputs": [], "source": [ "#initializations\n", "bhvr_onto_file = 'ontology_page3_bhvr.csv'\n", "event_onto_file = 'ontology_page3_event.csv'\n", "embedding = TransformerDocumentEmbeddings('distilbert-base-uncased')\n", "client = QdrantClient(\n", " host=os.environ[\"QDRANT_API_URL\"], \n", " api_key=os.environ[\"QDRANT_API_KEY\"],\n", " timeout=60,\n", " port=443\n", ")\n", "collection_name = \"my_collection\"\n", "model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')\n", "vector_dim = 384 #{distilbert-base-uncased: 768, multi-qa-MiniLM-L6-cos-v1:384}\n", "sf_trig_model_name = \"setfit-zero-shot-classification-pbsp-p3-trig\"\n", "sf_trig_model = SetFitModel.from_pretrained(f\"aammari/{sf_trig_model_name}\")\n", "sf_cons_model_name = \"setfit-zero-shot-classification-pbsp-p3-cons\"\n", "sf_cons_model = SetFitModel.from_pretrained(f\"aammari/{sf_cons_model_name}\")\n", "\n", "# download nltk 'punkt' if not available\n", "try:\n", " nltk.data.find('tokenizers/punkt')\n", "except LookupError:\n", " nltk.download('punkt')\n", "\n", "# download nltk 'averaged_perceptron_tagger' if not available\n", "try:\n", " nltk.data.find('taggers/averaged_perceptron_tagger')\n", "except LookupError:\n", " nltk.download('averaged_perceptron_tagger')\n", " \n", "#argilla\n", "rg.init(\n", " api_url=os.environ[\"ARGILLA_API_URL\"],\n", " api_key=os.environ[\"ARGILLA_API_KEY\"]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "17fe501c", "metadata": { "hide_input": false }, "outputs": [], "source": [ "#Behaviour ontology to be loaded from a CSV file\n", "bhvr_onto_df = pd.read_csv(bhvr_onto_file, header=None).dropna()\n", "bhvr_onto_df.columns = ['text']\n", "bhvr_onto_lst = bhvr_onto_df['text'].tolist()" ] }, { "cell_type": "code", "execution_count": null, "id": "7fa6ce86", "metadata": {}, "outputs": [], "source": [ "#Setting event ontology to be loaded from a CSV file\n", "event_onto_df = pd.read_csv(event_onto_file, header=None).dropna()\n", "event_onto_df.columns = ['text']\n", "event_onto_lst = event_onto_df['text'].tolist()" ] }, { "cell_type": "code", "execution_count": null, "id": "72c2c6f9", "metadata": { "hide_input": false }, "outputs": [], "source": [ "#Text 
Preprocessing\n", "try:\n", " nlp = spacy.load('en_core_web_sm')\n", "except OSError:\n", " spacy.cli.download('en_core_web_sm')\n", " nlp = spacy.load('en_core_web_sm')\n", "sw_lst = text.ENGLISH_STOP_WORDS\n", "def preprocess(onto_lst):\n", " cleaned_onto_lst = []\n", " pattern = re.compile(r'^[a-z ]*$')\n", " for document in onto_lst:\n", " text = []\n", " doc = nlp(document)\n", " person_tokens = []\n", " for w in doc:\n", " if w.ent_type_ == 'PERSON':\n", " person_tokens.append(w.lemma_)\n", " for w in doc:\n", " if not w.is_stop and not w.is_punct and not w.like_num and not len(w.text.strip()) == 0 and not w.lemma_ in person_tokens:\n", " text.append(w.lemma_.lower())\n", " texts = [t for t in text if len(t) > 1 and pattern.search(t) is not None and t not in sw_lst]\n", " cleaned_onto_lst.append(\" \".join(texts))\n", " return cleaned_onto_lst\n", "\n", "cl_bhvr_onto_lst = preprocess(bhvr_onto_lst)\n", "cl_event_onto_lst = preprocess(event_onto_lst)\n", "\n", "#pprint(cl_bhvr_onto_lst)\n", "#pprint(cl_event_onto_lst)" ] }, { "cell_type": "code", "execution_count": null, "id": "a1f934eb", "metadata": {}, "outputs": [], "source": [ "#compute document embeddings\n", "\n", "# distilbert-base-uncased from Flair\n", "def embeddings(cl_onto_lst):\n", " emb_onto_lst = []\n", " for doc in cl_onto_lst:\n", " sentence = Sentence(doc)\n", " embedding.embed(sentence)\n", " emb_onto_lst.append(sentence.embedding.tolist())\n", " return emb_onto_lst\n", "\n", "# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n", "def sentence_embeddings(cl_onto_lst):\n", " emb_onto_lst_temp = model.encode(cl_onto_lst)\n", " emb_onto_lst = [x.tolist() for x in emb_onto_lst_temp]\n", " return emb_onto_lst\n", "\n", "'''\n", "emb_bhvr_onto_lst = embeddings(cl_bhvr_onto_lst)\n", "emb_fh_onto_lst = embeddings(cl_fh_onto_lst)\n", "emb_rep_onto_lst = embeddings(cl_rep_onto_lst)\n", "'''\n", "\n", "emb_bhvr_onto_lst = sentence_embeddings(cl_bhvr_onto_lst)\n", "emb_event_onto_lst = sentence_embeddings(cl_event_onto_lst)" ] }, { "cell_type": "code", "execution_count": null, "id": "6302e312", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#add to qdrant collection\n", "def add_to_collection():\n", " global cl_bhvr_onto_lst, emb_bhvr_onto_lst, cl_event_onto_lst, emb_event_onto_lst\n", " client.recreate_collection(\n", " collection_name=collection_name,\n", " vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),\n", " )\n", " doc_count = len(emb_bhvr_onto_lst) + len(emb_event_onto_lst)\n", " ids = list(range(1, doc_count+1))\n", " payloads = [{\"ontology\": \"behaviours\", \"phrase\": x} for x in cl_bhvr_onto_lst] + \\\n", " [{\"ontology\": \"setting_events\", \"phrase\": y} for y in cl_event_onto_lst]\n", " vectors = emb_bhvr_onto_lst+emb_event_onto_lst\n", " client.upsert(\n", " collection_name=f\"{collection_name}\",\n", " points=Batch(\n", " ids=ids,\n", " payloads=payloads,\n", " vectors=vectors\n", " ),\n", " )\n", "\n", "def count_collection():\n", " return len(client.scroll(\n", " collection_name=f\"{collection_name}\"\n", " )[0])\n", "\n", "add_to_collection()\n", "point_count = count_collection()" ] }, { "cell_type": "code", "execution_count": null, "id": "dff0ffdf", "metadata": {}, "outputs": [], "source": [ "#print(point_count)" ] }, { "cell_type": "code", "execution_count": null, "id": "b74861d4", "metadata": {}, "outputs": [], "source": [ "#query_filter=Filter(\n", "# must=[ \n", "# FieldCondition(\n", "# key='ontology',\n", "# match=MatchValue(value=\"setting_events\")# 
{ "cell_type": "code", "execution_count": null, "id": "b74861d4", "metadata": {}, "outputs": [], "source": [ "#query_filter=Filter(\n", "#    must=[\n", "#        FieldCondition(\n", "#            key='ontology',\n", "#            match=MatchValue(value=\"setting_events\")  # match points from the setting-event ontology\n", "#        )\n", "#    ]\n", "#)" ] },
{ "cell_type": "code", "execution_count": null, "id": "40070313", "metadata": {}, "outputs": [], "source": [ "# Noun phrase extraction with an NLTK chunk grammar\n", "def extract_noun_phrases(text):\n", "    # Tokenize the text\n", "    tokens = nltk.word_tokenize(text)\n", "\n", "    # Part-of-speech tag the tokens\n", "    tagged_tokens = nltk.pos_tag(tokens)\n", "\n", "    # Define the noun phrase grammar\n", "    grammar = r\"\"\"\n", "        NP: {