# NDIS Project - PBSP Scoring - Page 3

In [None]:
import os
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display, clear_output, Javascript, HTML, Markdown
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, Batch, Filter, FieldCondition, Range, MatchValue
import json
import spacy
from spacy import displacy
import nltk
from nltk import sent_tokenize
from sklearn.feature_extraction import text
from pprint import pprint
import re
from flair.embeddings import TransformerDocumentEmbeddings
from flair.data import Sentence
from flair.models import TARSClassifier
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import argilla as rg
from argilla.metrics.text_classification import f1
from typing import Dict
from setfit import SetFitModel
from tqdm import tqdm
import time
# NOTE(review): tqdm is disabled, so this loop only sleeps ~15 s (15 x 1 s) with
# no visible progress bar. Presumably a deliberate start-up delay - confirm it
# is still needed.
for i in tqdm(range(15), disable=True):
    time.sleep(1)

In [None]:
#initializations
# Flair document embedder; only used by the Flair-based helpers
# (embeddings / get_query_vector).
embedding = TransformerDocumentEmbeddings('distilbert-base-uncased')
# Qdrant vector store; connection details come from the environment
# (a KeyError here means the variable is not set).
client = QdrantClient(
    host=os.environ["QDRANT_API_URL"],
    api_key=os.environ["QDRANT_API_KEY"],
    timeout=60,
    port=443
)
collection_name = "my_collection"
# Sentence-transformer used for the glossary / verb-phrase vectors in Qdrant.
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
vector_dim = 384 #{distilbert-base-uncased: 768, multi-qa-MiniLM-L6-cos-v1:384}
sf_func_model_name = "setfit-zero-shot-classification-pbsp-p3-func"
sf_func_model = SetFitModel.from_pretrained(f"aammari/{sf_func_model_name}")
tars_model_path = 'few-shot-model-gain-avoid'
# `load` is a classmethod: calling it on the class avoids first constructing a
# throw-away default TARSClassifier (and its backbone transformer).
tars = TARSClassifier.load(os.path.join(tars_model_path, 'best-model.pt'))

# Make sure the NLTK data used by sent_tokenize / word_tokenize / pos_tag is
# available. Newer NLTK releases look up the "punkt_tab" and "_eng" resource
# names while older ones use the classic names, so ensure both variants.
for _nltk_path, _nltk_pkg in (
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
    ('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng'),
):
    try:
        nltk.data.find(_nltk_path)
    except LookupError:
        nltk.download(_nltk_pkg)

#argilla
rg.init(
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"]
)

### <font color='red'>Domain Expert Section</font>
#### Enter the Topic Glossary

In [None]:
# Default behaviour glossary; one phrase per line in the editor below.
bhvr_onto_lst = ['hit employees', 'push people', 'throw objects', 'beat students']
bhvr_onto_text_input = widgets.Textarea(
    value="\n".join(bhvr_onto_lst),
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout=dict(height='100%', width='90%'),
)
bhvr_onto_label = widgets.Label(value='Behaviours')
# Label stacked above its editor, in a fixed-size column.
bhvr_onto_box = widgets.VBox(
    [bhvr_onto_label, bhvr_onto_text_input],
    layout=dict(width='400px', height='150px'),
)

In [None]:
# Default functional-hypothesis glossary; one phrase per line in the editor.
fh_onto_lst = ['Gain the teacher attention', 'Complete work in class', 'Avoid difficult work']

fh_onto_text_input = widgets.Textarea(
    value="\n".join(fh_onto_lst),
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout=dict(height='100%', width='90%'),
)
fh_onto_label = widgets.Label(value='Functional Hypothesis')
# Label stacked above its editor, in a fixed-size column.
fh_onto_box = widgets.VBox(
    [fh_onto_label, fh_onto_text_input],
    layout=dict(width='400px', height='150px'),
)

In [None]:
# Default replacement-behaviour glossary; one phrase per line in the editor.
rep_onto_lst = ['Ask teacher for help', 'Replace full body slam', 'Use a next sign']

rep_onto_text_input = widgets.Textarea(
    value="\n".join(rep_onto_lst),
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout=dict(height='100%', width='90%'),
)
rep_onto_label = widgets.Label(value='Replacement Behaviour')
rep_onto_box = widgets.VBox(
    [rep_onto_label, rep_onto_text_input],
    layout=dict(width='400px', height='150px'),
)

# Only the functional-hypothesis editor is shown on this page; the behaviour
# and replacement editors are built but not placed in the row.
onto_boxes = widgets.HBox([fh_onto_box], layout=dict(width='90%', height='150px'))

display(onto_boxes)

In [None]:
#Text Preprocessing
# Load the small English spaCy pipeline; if it is not installed (OSError),
# download it once and retry.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    spacy.cli.download('en_core_web_sm')
    nlp = spacy.load('en_core_web_sm')
# scikit-learn's English stop-word list; preprocess() filters with it in
# addition to spaCy's own is_stop flag.
sw_lst = text.ENGLISH_STOP_WORDS
def preprocess(onto_lst):
    """Normalise phrases into space-joined lowercase lemmas for embedding.

    For each phrase, run the spaCy pipeline ``nlp`` and drop these tokens:
    stop words, punctuation, number-like tokens, whitespace-only tokens, and
    any token whose lemma equals the lemma of a PERSON-entity token. Of the
    remaining tokens, keep the lowercase lemmas that are longer than one
    character, consist only of a-z and spaces, and are not in the
    scikit-learn stop list ``sw_lst``.

    Args:
        onto_lst: iterable of raw phrase strings.

    Returns:
        list[str]: one cleaned string per input phrase (possibly ``""``),
        in input order.
    """
    pattern = re.compile(r'^[a-z ]*$')
    cleaned_onto_lst = []
    for document in onto_lst:
        doc = nlp(document)
        # Lemmas of PERSON-entity tokens; a set gives O(1) membership tests.
        person_lemmas = {w.lemma_ for w in doc if w.ent_type_ == 'PERSON'}
        # Named `lemmas` (not `text`) so the imported sklearn `text` module
        # is not shadowed inside this function.
        lemmas = [
            w.lemma_.lower()
            for w in doc
            if not (w.is_stop or w.is_punct or w.like_num)
            and w.text.strip()
            and w.lemma_ not in person_lemmas
        ]
        kept = [
            t for t in lemmas
            if len(t) > 1 and pattern.search(t) is not None and t not in sw_lst
        ]
        cleaned_onto_lst.append(" ".join(kept))
    return cleaned_onto_lst

# Clean the three default glossaries once at start-up (behaviours,
# functional hypothesis, replacement behaviour - in that order).
cl_bhvr_onto_lst, cl_fh_onto_lst, cl_rep_onto_lst = (
    preprocess(glossary)
    for glossary in (bhvr_onto_lst, fh_onto_lst, rep_onto_lst)
)

In [None]:
#compute document embeddings

# distilbert-base-uncased from Flair
def embeddings(cl_onto_lst):
    """Embed each cleaned phrase with the Flair ``embedding`` model.

    Returns:
        list[list[float]]: one embedding per phrase, in input order.
    """
    def _embed_one(phrase):
        # Flair embeds in place on the Sentence object.
        wrapped = Sentence(phrase)
        embedding.embed(wrapped)
        return wrapped.embedding.tolist()

    return [_embed_one(phrase) for phrase in cl_onto_lst]

# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers
def sentence_embeddings(cl_onto_lst):
    """Batch-encode cleaned phrases with the sentence-transformer ``model``.

    Returns:
        list[list[float]]: one vector per phrase, converted to plain lists.
    """
    return [vec.tolist() for vec in model.encode(cl_onto_lst)]

# Alternative (Flair distilbert-base-uncased, 768-dim - vector_dim must match):
#   emb_bhvr_onto_lst = embeddings(cl_bhvr_onto_lst)
#   emb_fh_onto_lst = embeddings(cl_fh_onto_lst)
#   emb_rep_onto_lst = embeddings(cl_rep_onto_lst)

emb_bhvr_onto_lst, emb_fh_onto_lst, emb_rep_onto_lst = (
    sentence_embeddings(cleaned)
    for cleaned in (cl_bhvr_onto_lst, cl_fh_onto_lst, cl_rep_onto_lst)
)

In [None]:
#add to qdrant collection
def add_to_collection():
    """(Re)build the Qdrant collection from the three glossaries.

    Drops and recreates ``collection_name`` with ``vector_dim``-sized cosine
    vectors, then upserts one point per phrase. Ids run 1..N in the order
    behaviours, functional hypothesis, replacement behaviour. Each payload
    holds its ontology name and cleaned phrase. It reads the module-level
    ``cl_*_onto_lst`` / ``emb_*_onto_lst`` lists but does not rebind them.
    """
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),
    )
    groups = (
        ("behaviours", cl_bhvr_onto_lst, emb_bhvr_onto_lst),
        ("functional_hypothesis", cl_fh_onto_lst, emb_fh_onto_lst),
        ("replacement_behaviour", cl_rep_onto_lst, emb_rep_onto_lst),
    )
    payloads = [
        {"ontology": ontology, "phrase": phrase}
        for ontology, phrases, _ in groups
        for phrase in phrases
    ]
    vectors = [vec for _, _, vecs in groups for vec in vecs]
    client.upsert(
        collection_name=collection_name,
        points=Batch(
            ids=list(range(1, len(vectors) + 1)),
            payloads=payloads,
            vectors=vectors,
        ),
    )

def count_collection():
    """Return the exact number of points stored in ``collection_name``.

    ``client.scroll`` returns only its first page of records (10 by default),
    so counting those records under-reported any collection larger than one
    page. ``client.count(..., exact=True)`` asks Qdrant for the true total.
    """
    return client.count(collection_name=collection_name, exact=True).count

add_to_collection()
# Total number of glossary points; search_collection uses it as the search
# `limit` so every point of the filtered ontology can be returned.
point_count = count_collection()

In [None]:
# Filter that keeps only functional-hypothesis glossary points.
# NOTE(review): not referenced in the code shown - search_collection builds
# its own per-call filter.
query_filter = Filter(
    must=[
        FieldCondition(key='ontology', match=MatchValue(value="functional_hypothesis")),
    ]
)

In [None]:
#verb phrase extraction
def extract_vbs(data_chunked):
    """Yield the text of every chunk subtree in an NLTK chunk parse.

    ``data_chunked`` is the tree returned by ``RegexpParser.parse``. Its
    children are either ``(word, tag)`` tuples (tokens outside any chunk) or
    chunk subtrees whose children are ``(word, tag)`` tuples. Each subtree is
    rendered as its words joined by single spaces.
    """
    for node in data_chunked:
        # Check the node type rather than its length: a length test cannot
        # tell a (word, tag) pair from a two-token chunk such as "avoid work",
        # which the `<VB><.*>{0,3}<NN>` grammars explicitly allow.
        if isinstance(node, tuple):
            continue
        yield " ".join(str(leaf[0]) for leaf in node)

def get_verb_phrases(nltk_query):
    """Extract candidate verb phrases from free text.

    The text is word-tokenised and POS-tagged with NLTK. It is then chunked
    separately with each of 40 single-rule grammars: one verb tag (VB, VBN,
    VBG, VBP, VBZ), up to three arbitrary tags, then a noun target
    (NN/NNP/NNS/NNPS, optionally preceded by PRP). Each chunk is rendered by
    ``extract_vbs``.

    Returns:
        list[str]: phrases in grammar order, keeping only the first
        occurrence of each. Different rules can chunk the same span (e.g.
        both VB..NN and VB..PRP NN match "give him money"). Without
        de-duplication such a phrase would be highlighted and scored twice.
    """
    data_tok = nltk.word_tokenize(nltk_query) #tokenisation
    data_pos = nltk.pos_tag(data_tok) #POS tagging
    verb_tags = ('VB', 'VBN', 'VBG', 'VBP', 'VBZ')
    # Same target order as the original hand-written grammar list.
    noun_targets = (
        '<NN>', '<NNP>', '<PRP><NN>', '<PRP><NNS>',
        '<NNPS>', '<NNS>', '<PRP><NNP>', '<PRP><NNPS>',
    )
    # e.g. "CUSTOMCHUNK: {<VB><.*>{0,3}<NN>}"
    cfgs = [
        f"CUSTOMCHUNK: {{<{verb}><.*>{{0,3}}{target}}}"
        for verb in verb_tags
        for target in noun_targets
    ]
    vbs = []
    for cfg in cfgs:
        chunker = nltk.RegexpParser(cfg)
        vbs.extend(extract_vbs(chunker.parse(data_pos)))
    # dict preserves insertion order: drop repeats, keep first occurrence.
    return list(dict.fromkeys(vbs))

In [None]:
#query and get score

# distilbert-base-uncased from Flair
def get_query_vector(query):
    """Embed ``query`` with the Flair ``embedding`` model; return a plain list."""
    wrapped = Sentence(query)
    embedding.embed(wrapped)
    return wrapped.embedding.tolist()

# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers
def sentence_get_query_vector(query):
    """Return ``model.encode(query)`` unchanged (not converted to a list)."""
    return model.encode(query)

def search_collection(ontology, query_vector):
    """Search the glossary collection for points of one ontology.

    Args:
        ontology: payload value to match on the ``ontology`` field
            (e.g. ``'functional_hypothesis'``).
        query_vector: embedding to compare against; its length must equal
            ``vector_dim``.

    Returns:
        The Qdrant hits, with payloads, up to ``point_count`` of them.
    """
    query_filter = Filter(
        must=[FieldCondition(key='ontology', match=MatchValue(value=ontology))]
    )
    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        query_filter=query_filter,
        # `with_payload` is the current name; `append_payload` is its
        # deprecated alias.
        with_payload=True,
        limit=point_count,
    )

# Minimum Qdrant (cosine) similarity for a verb phrase to count as matching a
# glossary entry in on_func_button_next.
semantic_passing_score = 0.50


# Example usage:
#ontology = 'behaviours'
#query = 'punch father face'
#query_vector = sentence_get_query_vector(query)
#hist = search_collection(ontology, query_vector)

In [None]:
# format output
def color(df):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#ADD8E6')

def annotate_query(highlights, query):
    """Build displaCy manual-mode entity spans for glossary-matched phrases.

    Each highlight is located literally at its first occurrence in
    ``query``. Highlights not found in ``query`` are skipped.

    Args:
        highlights: phrases to mark.
        query: the text being rendered.

    Returns:
        list[dict]: one ``{"start", "end", "label": 'GLOSSARY'}`` per found
        highlight, in the order of ``highlights``.
    """
    ents = []
    for h in highlights:
        # Literal search: phrases come from user text and may contain regex
        # metacharacters ("(", "+", "?"). Using them as patterns raised
        # re.error or matched the wrong span.
        start = query.find(h)
        if start != -1:
            ents.append({"start": start, "end": start + len(h), "label": 'GLOSSARY'})
    return ents

In [None]:
#setfit sentence extraction
def extract_sentences(nltk_query):
    """Split ``nltk_query`` into a list of sentences with NLTK's sent_tokenize."""
    return sent_tokenize(nltk_query)

In [None]:
def convert_df(result_df):
    """Reshape scored rows into the text/prediction frame Argilla ingests.

    Args:
        result_df: DataFrame with ``Phrase``, ``Topic`` and ``Score`` columns.

    Returns:
        DataFrame with columns ``text`` (the phrase) and ``prediction``
        (``[[topic, min(score, 1.0)]]`` per row), keeping ``result_df``'s index.
        Empty input yields an empty frame with those two columns.
    """
    # Built from plain lists rather than DataFrame.apply(axis=1): on an empty
    # frame, apply returns a DataFrame and the column assignment raised.
    predictions = [
        [[topic, min(score, 1.0)]]  # scores are capped at 1.0
        for topic, score in zip(result_df['Topic'], result_df['Score'])
    ]
    return pd.DataFrame(
        {'text': list(result_df['Phrase']), 'prediction': predictions},
        index=result_df.index,
        columns=['text', 'prediction'],
    )

In [None]:
def custom_f1(data: Dict[str, float], title: str):
    """Plot Argilla F1 metrics as a two-row Plotly bar figure.

    Row 1 shows the ``data`` entries whose key contains "macro", in dict
    order, on an x axis labelled precision / recall / f1. Row 2 shows every
    per-label entry (keys containing none of "macro", "micro", "support").
    Each consecutive group of three bars shares a randomly chosen palette
    colour.

    NOTE(review): assumes the per-label metrics arrive as consecutive groups
    of three (precision/recall/f1 per label) - confirm against Argilla's f1
    output.

    Args:
        data: metric name -> value, as in ``dict(f1(...))['data']``.
        title: figure title.

    Returns:
        The Plotly figure.
    """
    from plotly.subplots import make_subplots
    import plotly.colors
    import random

    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=["Overall Model Score", "Model Score By Category"],
    )

    x = ['precision', 'recall', 'f1']
    macro_data = [v for k, v in data.items() if "macro" in k]
    fig.add_bar(
        x=x,
        y=macro_data,
        row=1,
        col=1,
    )
    per_label = {
        k: v
        for k, v in data.items()
        if all(key not in k for key in ["macro", "micro", "support"])
    }

    # One colour per group of three bars; ceil division so a trailing partial
    # group still gets a colour.
    num_labels = -(-len(per_label) // 3)
    fixed_colors = [str(c) for c in plotly.colors.qualitative.Plotly]
    # random.sample raises ValueError when asked for more colours than the
    # palette holds, so sample at most the palette size and cycle beyond it.
    colors = random.sample(fixed_colors, min(num_labels, len(fixed_colors)))

    fig.add_bar(
        x=list(per_label.keys()),
        y=list(per_label.values()),
        row=2,
        col=1,
        marker_color=[colors[(i // 3) % len(colors)] for i in range(len(per_label))],
    )
    fig.update_layout(showlegend=False, title_text=title)

    return fig

In [None]:
def get_null_class_df(sentences, result_df):
    """Label every sentence absent from ``result_df['Phrase']`` as NO FUNCTION.

    Returns:
        DataFrame with columns Phrase, Topic ('NO FUNCTION') and Score (0.90).
        There is one row per unmatched sentence, in input order.
    """
    already_labelled = set(result_df['Phrase'].tolist())
    unlabelled = [s for s in sentences if s not in already_labelled]
    count = len(unlabelled)
    return pd.DataFrame({
        'Phrase': unlabelled,
        'Topic': ['NO FUNCTION'] * count,
        'Score': [0.90] * count,
    })

In [None]:
#setfit func query and get predicted topic

def get_sf_func_topic(sentences):
    """Predict a class index (see ind_func_topic_dict) per sentence with SetFit.

    Returns an empty list for empty input without calling the model.
    NOTE(review): the guard assumes the SetFit head cannot score zero samples
    (e.g. an empty answer) - it avoids the call entirely.
    """
    if len(sentences) == 0:
        return []
    return list(sf_func_model(sentences))
def get_sf_func_topic_scores(sentences):
    """Return each sentence's highest SetFit class probability.

    Returns an empty list for empty input without calling the model, which
    keeps this consistent with get_sf_func_topic.
    """
    if len(sentences) == 0:
        return []
    probabilities = sf_func_model.predict_proba(sentences)
    return [max(row) for row in probabilities]

In [None]:
# setfit func format output
# SetFit class index -> display label.
ind_func_topic_dict = {
        0: 'NO FUNCTION',
        1: 'FUNCTION',
    }

# NOTE(review): highlight_threshold is not referenced in the code shown.
highlight_threshold = 0.25
# FUNCTION sentences scoring >= passing_score are highlighted green
# ('#CCFFCC'), others amber ('#FFCC66') in on_func_button_next.
passing_score = 0.50

def sf_func_color(df):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#CCFFCC')

def sf_annotate_query(highlights, query, topics):
    """Build displaCy manual-mode entity spans for classified sentences.

    Each highlight is located literally at its first occurrence in
    ``query``. Offsets refer to ``query`` exactly as given, which is the same
    text the caller renders. Highlights not found in ``query`` are skipped.

    Args:
        highlights: sentences to mark.
        query: the text being rendered.
        topics: label for each highlight (paired positionally).

    Returns:
        list[dict]: one ``{"start", "end", "label"}`` per found highlight.
    """
    ents = []
    # No query.strip(): the caller renders the unstripped query, so stripping
    # leading whitespace here shifted every span to the left.
    for h, t in zip(highlights, topics):
        start = query.find(h)  # literal match, no regex interpretation
        if start != -1:
            ents.append({"start": start, "end": start + len(h), "label": t})
    return ents

In [None]:
#query and get predicted topic

# TARS label value -> class index (see ind_topic_dict for display labels).
p_classes = {'gain_attention': 0,
             'avoid_attention': 1,
             'unknown': 2
            }
def get_topic(sentences):
    """Predict a gain/avoid class index per sentence with the TARS model.

    Returns:
        list[int]: ``p_classes`` index per sentence. It is ``2`` ('unknown')
        when no label is available or the label is not in ``p_classes``.
    """
    preds = []
    for t in sentences:
        sentence = Sentence(t)
        tars.predict(sentence)
        try:
            pred = p_classes[sentence.tag]
        except (KeyError, IndexError, AttributeError):
            # Unlisted label (KeyError) or, presumably, no label predicted.
            # Narrowed from a bare `except:` so interrupts are not swallowed.
            pred = p_classes['unknown']
        preds.append(pred)
    return preds
def get_topic_scores(sentences):
    """Return the TARS confidence of each sentence's predicted label.

    Each sentence is predicted again, independently of get_topic. When no
    score is available, the fallback value 0.75 is used.
    """
    preds = []
    for t in sentences:
        sentence = Sentence(t)
        tars.predict(sentence)
        try:
            pred = sentence.score
        except (IndexError, AttributeError):
            # Presumably raised when TARS predicted no label. Narrowed from a
            # bare `except:` so interrupts are not swallowed.
            pred = 0.75
        preds.append(pred)
    return preds

In [None]:
# format output
# TARS class index (as produced via p_classes) -> display label.
ind_topic_dict = {
        0: 'GAIN-ATTENTION',
        1: 'AVOID-ATTENTION',
        2: 'UNKNOWN'
    }

# Highlight / bar colour per gain-avoid label. UNKNOWN has no entry; those
# sentences are filtered out before highlighting.
topic_color_dict = {
        'GAIN-ATTENTION': '#FFCCCC',
        'AVOID-ATTENTION': '#CCFFFF'
    }

# Minimum TARS score for a GAIN/AVOID sentence to be highlighted and counted.
gain_avoid_passing_score = 0.25

def gain_avoid_color(df, color):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)

def gain_avoid_annotate_query(highlights, query, topics):
    """Build displaCy manual-mode entity spans for gain/avoid sentences.

    Each highlight is located literally at its first occurrence in
    ``query``. Highlights not found in ``query`` are skipped.

    Args:
        highlights: sentences to mark.
        query: the text being rendered.
        topics: label for each highlight (paired positionally).

    Returns:
        list[dict]: one ``{"start", "end", "label"}`` per found highlight.
    """
    ents = []
    for h, t in zip(highlights, topics):
        # Literal search: sentences may contain "?" or unbalanced "(", which
        # as regex patterns mis-matched or raised re.error.
        start = query.find(h)
        if start != -1:
            ents.append({"start": start, "end": start + len(h), "label": t})
    return ents

In [None]:
def path_to_image_html(path):
    """Wrap an image path in a 30x15 ``<img>`` tag (path is not HTML-escaped)."""
    return "".join(('<img src="', path, '" width="30" height="15" />'))

final_passing = 0.0
def display_final_df(agg_df):
    """Render a centred HTML table marking each approach as used or not.

    An approach is "used" (green tick) when its label is in ``agg_df``'s index
    and its ``Final_Score`` exceeds ``final_passing``. Approaches missing from
    the index are listed after the present ones, each with a red cross.
    """
    all_crits = ['GAIN-ATTENTION', 'AVOID-ATTENTION']
    index_values = agg_df.index.tolist()
    found = [c for c in all_crits if c in index_values]
    missing = [c for c in all_crits if c not in index_values]
    icons = [
        './tick_green.png' if agg_df.loc[c, 'Final_Score'] > final_passing else './cross_red.png'
        for c in found
    ]
    table = pd.DataFrame({'Topic': found, 'USED': icons})
    if missing:
        table = pd.concat([
            table,
            pd.DataFrame({'Topic': missing, 'USED': ['./cross_red.png'] * len(missing)}),
        ])
    table = table.set_index('Topic')
    pd.set_option('display.max_colwidth', None)
    html_table = table.to_html(
        classes=["align-center"],
        index=True,
        escape=False,
        formatters=dict(USED=path_to_image_html),
    )
    display(HTML('<div style="text-align: center;">' + html_table + '</div>'))

### <font color='red'>Practitioner Section</font>
#### Enter a summary statement outlining the <font color='blue'>functional hypothesis</font>

In [None]:
#demo with Voila

# Prompt and free-text area for the practitioner's functional-hypothesis answer.
func_label = widgets.Label(value='Please type your answer:')
func_text_input = widgets.Textarea(
    value='',
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout={'height': '300px', 'width': '90%'}
)

# Action buttons; handlers are attached with on_click further down.
func_nlp_btn = widgets.Button(
    description='Score Functions',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Score Functions',
    icon='check',
    layout={'height': '70px', 'width': '250px'}
)
gain_avoid_nlp_btn = widgets.Button(
    description='Detect Gain / Avoid',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Detect Gain / Avoid',
    icon='check',
    layout={'height': '70px', 'width': '250px'}
)
bhvr_agr_btn = widgets.Button(
    description='Validate Data',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Validate Data',
    icon='check',
    layout={'height': '70px', 'width': '250px'}
)
bhvr_eval_btn = widgets.Button(
    description='Evaluate Model',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Evaluate Model',
    icon='check',
    layout={'height': '70px', 'width': '250px'}
)
# Two button rows: Argilla validate/evaluate, and the two scoring actions.
btn_box = widgets.HBox([bhvr_agr_btn, bhvr_eval_btn], 
                       layout={'width': '100%', 'height': '160%'})
func_btn_box = widgets.HBox([func_nlp_btn, gain_avoid_nlp_btn], 
                       layout={'width': '100%', 'height': '160%'})
# Shared output area that every handler clears and writes into.
func_outt = widgets.Output()
func_outt.layout.height = '100%'
func_outt.layout.width = '100%'
func_box = widgets.VBox([func_text_input, func_btn_box, btn_box, func_outt], 
                   layout={'width': '100%', 'height': '160%'})
# Name of the Argilla dataset replaced by on_agr_button_next and evaluated by on_eval_button_next.
dataset_rg_name = 'pbsp-page3-func-argilla-ds'
# Latest labelled rows for Argilla (set by the scoring handlers; None until an answer is scored).
agrilla_df = None
# True once data has been pushed to Argilla, which enables model evaluation.
annotated = False
# Labelled frames appended by on_func_button_next; not cleared in the code shown.
sub_2_result_dfs = []
def on_func_button_next(b):
    """Click handler for "Score Functions".

    Re-indexes the functional-hypothesis glossary from its text area. It then
    scores the practitioner's answer in two ways and renders both in
    ``func_outt``:

    1. Verb phrases from the answer are embedded and searched against the
       glossary. Matches scoring >= ``semantic_passing_score`` are
       highlighted and tabulated ("Similar to Glossary").
    2. Each sentence is classified FUNCTION / NO FUNCTION with SetFit, and
       FUNCTION sentences are highlighted ("Detected Functions").

    The labelled rows from both passes are appended to ``sub_2_result_dfs``.
    Their concatenation is stored in ``agrilla_df`` for the Argilla step.

    Args:
        b: the clicked button (unused).
    """
    global fh_onto_lst, cl_fh_onto_lst, emb_fh_onto_lst, agrilla_df, point_count
    with func_outt:
        clear_output()
        fh_onto_lst = fh_onto_text_input.value.split("\n")
        cl_fh_onto_lst = preprocess(fh_onto_lst)
        # Cleaned glossary phrase -> phrase as typed by the expert.
        orig_cl_dict = {x: y for x, y in zip(cl_fh_onto_lst, fh_onto_lst)}
        emb_fh_onto_lst = sentence_embeddings(cl_fh_onto_lst)
        add_to_collection()
        # The collection was just rebuilt and may have grown; refresh the
        # count that search_collection uses as its result limit.
        point_count = count_collection()
        query = func_text_input.value
        vbs = get_verb_phrases(query)
        cl_vbs = preprocess(vbs)
        emb_vbs = sentence_embeddings(cl_vbs)
        highlights = []
        highlight_scores = []
        result_dfs = []
        for vb, query_vector in zip(vbs, emb_vbs):
            hist = search_collection('functional_hypothesis', query_vector)
            hist_dict = [dict(x) for x in hist]
            scores = [x['score'] for x in hist_dict]
            payloads = [orig_cl_dict[x['payload']['phrase']] for x in hist_dict]
            result_df = pd.DataFrame({'Score': scores, 'Glossary': payloads})
            # .copy(): the boolean filter returns a slice, and adding the
            # 'Phrase' column to it raised SettingWithCopyWarning in the output.
            result_df = result_df[result_df['Score'] >= semantic_passing_score].copy()
            if len(result_df) == 0:
                continue
            highlights.append(vb)
            highlight_scores.append(result_df.Score.max())
            result_df['Phrase'] = [vb] * len(result_df)
            result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)
            result_dfs.append(result_df)
        ents = []
        colors = {}
        if len(highlights) > 0:
            ents = annotate_query(highlights, query)
            for ent in ents:
                colors[ent['label']] = '#ADD8E6'

        #setfit function
        sentences = extract_sentences(query)
        cl_sentences = preprocess(sentences)
        topic_inds = get_sf_func_topic(cl_sentences)
        topics = [ind_func_topic_dict[i] for i in topic_inds]
        scores = get_sf_func_topic_scores(cl_sentences)
        sf_func_result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})
        sf_func_sub_result_df = sf_func_result_df[sf_func_result_df['Topic'] == 'FUNCTION']
        sub_2_result_df = sf_func_result_df[sf_func_result_df['Topic'] == 'NO FUNCTION']
        sub_2_result_df = pd.concat([sub_2_result_df, sf_func_sub_result_df]).reset_index(drop=True)
        # NOTE(review): sub_2_result_dfs is module-level and is not cleared
        # here, so rows from earlier clicks also end up in agrilla_df -
        # confirm whether accumulating across answers is intended.
        sub_2_result_dfs.append(sub_2_result_df)
        sf_func_highlights = []
        sf_func_ents = []
        if len(sf_func_sub_result_df) > 0:
            sf_func_highlights = sf_func_sub_result_df['Phrase'].tolist()
            sf_func_highlight_topics = sf_func_sub_result_df['Topic'].tolist()
            sf_func_highlight_scores = sf_func_sub_result_df['Score'].tolist()
            sf_func_ents = sf_annotate_query(sf_func_highlights, query, sf_func_highlight_topics)
            # Green for confident FUNCTION sentences, amber otherwise.
            for ent, hs in zip(sf_func_ents, sf_func_highlight_scores):
                if hs >= passing_score:
                    colors[ent['label']] = '#CCFFCC'
                else:
                    colors[ent['label']] = '#FFCC66'
        options = {"ents": list(colors), "colors": colors}
        if len(sf_func_ents) > 0:
            ents = ents + sf_func_ents

        ex = [{"text": query,
               "ents": ents,
               "title": None}]
        if len(ents) > 0:
            title = "Answer Highlights"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            html = displacy.render(ex, style="ent", manual=True, options=options)
            display(HTML(html))
        if len(result_dfs) > 0:
            title = "Similar to Glossary"
            display(HTML(f'<center><h1 style="background-color: #ADD8E6; padding: 5px 10px;">{title}</h1></center>'))
            result_df = pd.concat(result_dfs).reset_index(drop = True)
            result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)
            # Glossary-matched phrases are labelled FUNCTION; other verb
            # phrases get the NO FUNCTION null class.
            sub_2_result_df = result_df.copy()
            sub_2_result_df['Topic'] = ['FUNCTION'] * len(result_df)
            sub_2_result_df = sub_2_result_df[['Phrase', 'Topic', 'Score']].drop_duplicates().reset_index(drop=True)
            null_df = get_null_class_df(vbs, sub_2_result_df)
            if len(null_df) > 0:
                sub_2_result_df = pd.concat([sub_2_result_df, null_df]).reset_index(drop=True)
            sub_2_result_dfs.append(sub_2_result_df)
            # Keep, per phrase, only the glossary row(s) with the best score.
            agg_df = result_df.groupby(result_df.Phrase).max()
            agg_df['Phrase'] = agg_df.index
            agg_df = agg_df.reset_index(drop=True)
            agg_df = agg_df.drop(columns=['Glossary'])
            result_df = pd.merge(result_df, agg_df, 'inner', ['Phrase', 'Score'])
            result_df = result_df[['Phrase', 'Glossary', 'Score']]
            result_df = result_df.set_index('Phrase')
            display(color(result_df))
        if len(sf_func_sub_result_df) > 0:
            title = "Detected Functions"
            display(HTML(f'<center><h1 style="background-color: #CCFFCC; padding: 5px 10px;">{title}</h1></center>'))
            result_df = sf_func_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)
            result_df = result_df.set_index('Phrase')
            display(sf_func_color(result_df))
        if len(sub_2_result_dfs) > 0:
            sub_2_result_df = pd.concat(sub_2_result_dfs).reset_index(drop=True)
            agrilla_df = sub_2_result_df.copy()

def on_gain_avoid_button_next(b):
    """Click handler for "Detect Gain / Avoid".

    Classifies each sentence of the answer with TARS (GAIN-ATTENTION,
    AVOID-ATTENTION or UNKNOWN). It then highlights and tabulates the
    gain/avoid sentences scoring >= gain_avoid_passing_score, plots each
    label's share of their summed score, shows the tick/cross summary, and
    stores the labelled rows in agrilla_df. When no sentence qualifies, it
    prints the raw query instead and leaves agrilla_df unchanged.

    Args:
        b: the clicked button (unused).
    """
    global agrilla_df
    with func_outt:
        clear_output()
        query = func_text_input.value
        sentences = extract_sentences(query)
        cl_sentences = preprocess(sentences)
        topic_inds = get_topic(cl_sentences)
        topics = [ind_topic_dict[i] for i in topic_inds]
        scores = get_topic_scores(cl_sentences)
        result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})
        # Confident GAIN/AVOID sentences are shown; UNKNOWN rows are only kept for Argilla.
        # NOTE(review): GAIN/AVOID rows below the threshold are in neither
        # frame, so they are not sent to Argilla - confirm intent.
        sub_result_df = result_df[(result_df['Score'] >= gain_avoid_passing_score) & (result_df['Topic'] != 'UNKNOWN')]
        sub_2_result_df = result_df[result_df['Topic'] == 'UNKNOWN']
        highlights = []
        if len(sub_result_df) > 0:
            highlights = sub_result_df['Phrase'].tolist()
            highlight_topics = sub_result_df['Topic'].tolist()    
            ents = gain_avoid_annotate_query(highlights, query, highlight_topics)
            colors = {}
            # NOTE(review): zip pairs ents with topics positionally; if a highlight
            # was not found in the query the pairing can shift.
            for ent, ht in zip(ents, highlight_topics):
                colors[ent['label']] = topic_color_dict[ht]

            ex = [{"text": query,
                   "ents": ents,
                   "title": None}]
            title = "Gaining & Avoidance Highlights"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            html = displacy.render(ex, style="ent", manual=True, jupyter=True, options={'colors': colors})
            display(HTML(html))
            title = "Used Approach Classifications"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            # One score table per label that has qualifying sentences.
            for top in topic_color_dict.keys():
                top_result_df = sub_result_df[sub_result_df['Topic'] == top]
                if len(top_result_df) > 0:
                    top_result_df = top_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)
                    top_result_df = top_result_df.set_index('Phrase')
                    top_result_df = top_result_df[['Score']]
                    display(HTML(
                        f'<left><h2 style="text-decoration: underline; text-decoration-color:{topic_color_dict[top]};">{top}</h2></left>'))
                    display(gain_avoid_color(top_result_df, topic_color_dict[top]))
            
            # Each label's summed score as a percentage of the total (Final_Score).
            agg_df = sub_result_df.groupby('Topic')['Score'].sum()
            agg_df = agg_df.to_frame()
            agg_df.index.name = 'Topic'
            agg_df.columns = ['Total Score']
            agg_df = agg_df.assign(
                Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00
            )
            agg_df = agg_df.sort_values(by='Final_Score', ascending=False)
            title = "Gaining & Avoidance Coverage"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            agg_df['Topic'] = agg_df.index
            # Labels with no qualifying sentences still get a 0% bar.
            rem_topics= [x for x in list(topic_color_dict.keys()) if not x in agg_df.Topic.tolist()]
            if len(rem_topics) > 0:
                rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})
                agg_df = pd.concat([agg_df, rem_agg_df])
            labels = agg_df['Final_Score'].round(1).astype('str') + '%'
            ax = agg_df.plot.bar(x='Topic', y='Final_Score', rot=0, figsize=(20, 5), align='center')
            # Axis formatting is re-applied for every bar container; harmless.
            for container in ax.containers:
                ax.bar_label(container, labels=labels)
                ax.yaxis.set_major_formatter(mtick.PercentFormatter())
                ax.legend(["Final Score (%)"])
                ax.set_xlabel('')
            plt.show()
            title = "Gaining & Avoidance Scores"
            display(HTML(f'<left><h1>{title}</h1></left>'))
            display_final_df(agg_df)
            if len(sub_2_result_df) > 0:
                sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)
            agrilla_df = sub_result_df.copy()
        else:
            # NOTE(review): debug-style fallback - echoes the query text verbatim.
            print(query)

def on_agr_button_next(b):
    """Click handler for "Validate Data": push ``agrilla_df`` to Argilla.

    Replaces the Argilla dataset ``dataset_rg_name`` with the latest scored
    rows and sets ``annotated`` so evaluation becomes available. If nothing
    has been scored yet, shows a warning instead.

    NOTE(review): confirm rg.delete does not raise when the dataset does not
    exist yet (e.g. on the very first validation).
    """
    global agrilla_df, annotated
    with func_outt:
        clear_output()
        if agrilla_df is None:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer first!</h2>"))
            return
        # Reshape into Argilla's text/prediction columns, then wrap as a dataset.
        records = rg.DatasetForTextClassification.from_pandas(convert_df(agrilla_df))
        # Remove the previous copy before logging the fresh one.
        rg.delete(dataset_rg_name)
        rg.log(records, name=dataset_rg_name)
        annotated = True
            
def on_eval_button_next(b):
    """Click handler for "Evaluate Model": plot Argilla F1 metrics.

    Only available after the data has been pushed to Argilla (``annotated``);
    otherwise shows a warning.
    """
    with func_outt:
        clear_output()
        if not annotated:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer and validate the data first!</h2>"))
            return
        metrics = dict(f1(dataset_rg_name))['data']
        display(custom_f1(metrics, "Model Evaluation Results"))

# Attach each button to its click handler.
func_nlp_btn.on_click(on_func_button_next)
gain_avoid_nlp_btn.on_click(on_gain_avoid_button_next)
bhvr_agr_btn.on_click(on_agr_button_next)
bhvr_eval_btn.on_click(on_eval_button_next)

# Render the prompt and the input / buttons / output panel.
display(func_label, func_box)