}\"\n",
" ]\n",
" vbs = []\n",
" for cfg_1 in cfgs: \n",
" chunker = nltk.RegexpParser(cfg_1)\n",
" data_chunked = chunker.parse(data_pos)\n",
" vbs += extract_vbs(data_chunked)\n",
" return vbs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1550437b",
"metadata": {},
"outputs": [],
"source": [
"#query and get score\n",
"\n",
"# distilbert-base-uncased from Flair\n",
"def get_query_vector(query):\n",
" sentence = Sentence(query)\n",
" embedding.embed(sentence)\n",
" query_vector = sentence.embedding.tolist()\n",
" return query_vector\n",
"\n",
"# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n",
"def sentence_get_query_vector(query):\n",
" query_vector = model.encode(query)\n",
" return query_vector\n",
"\n",
"def search_collection(ontology, query_vector):\n",
" query_filter=Filter(\n",
" must=[ \n",
" FieldCondition(\n",
" key='ontology',\n",
" match=MatchValue(value=ontology)\n",
" )\n",
" ]\n",
" )\n",
" \n",
" hits = client.search(\n",
" collection_name=f\"{collection_name}\",\n",
" query_vector=query_vector,\n",
" query_filter=query_filter, \n",
" append_payload=True, \n",
" limit=point_count \n",
" )\n",
" return hits\n",
"\n",
"semantic_passing_score = 0.50\n",
"\n",
"\n",
"#ontology = 'behaviours'\n",
"#query = 'punch father face'\n",
"#query_vector = sentence_get_query_vector(query)\n",
"#hist = search_collection(ontology, query_vector)"
]
},
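{
"cell_type": "code",
"execution_count": null,
"id": "a3f1c2d0",
"metadata": {},
"outputs": [],
"source": [
"# Hedged usage sketch for the two helpers above. It assumes `model`, `client`,\n",
"# `collection_name` and `point_count` were created in earlier cells; the helper name,\n",
"# the example query and the 'functional_hypothesis' ontology value are illustrative only.\n",
"def preview_hits(query, ontology='functional_hypothesis'):\n",
"    query_vector = sentence_get_query_vector(query)\n",
"    hits = search_collection(ontology, query_vector)\n",
"    # keep only hits above the semantic threshold and return (phrase, score) pairs\n",
"    return [(h.payload.get('phrase'), h.score) for h in hits if h.score >= semantic_passing_score]\n",
"\n",
"# preview_hits('escapes the classroom when work is presented')"
]
},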
{
"cell_type": "code",
"execution_count": null,
"id": "02fda761",
"metadata": {},
"outputs": [],
"source": [
"# format output\n",
"def color(df):\n",
" return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#ADD8E6')\n",
"\n",
"def annotate_query(highlights, query):\n",
" ents = []\n",
" for h in highlights:\n",
" ent_dict = {}\n",
" for match in re.finditer(h, query):\n",
" ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": 'GLOSSARY'}\n",
" break\n",
" if len(ent_dict.keys()) > 0:\n",
" ents.append(ent_dict)\n",
" return ents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79b519a6",
"metadata": {},
"outputs": [],
"source": [
"#setfit sentence extraction\n",
"def extract_sentences(nltk_query):\n",
" sentences = sent_tokenize(nltk_query)\n",
" return sentences"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38594d91",
"metadata": {},
"outputs": [],
"source": [
"def convert_df(result_df):\n",
" new_df = pd.DataFrame(columns=['text', 'prediction'])\n",
" new_df['text'] = result_df['Phrase']\n",
" new_df['prediction'] = result_df.apply(lambda row: [[row['Topic'], min(row['Score'], 1.0)]], axis=1)\n",
" return new_df"
]
},
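{
"cell_type": "code",
"execution_count": null,
"id": "b7e4d9a1",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the frame shape `convert_df` expects and produces. The rows below\n",
"# are made-up examples; real frames are built by the scoring callbacks further down.\n",
"example_result_df = pd.DataFrame({\n",
"    'Phrase': ['hits his brother to get the toy', 'the weather was sunny'],\n",
"    'Topic': ['FUNCTION', 'NO FUNCTION'],\n",
"    'Score': [0.91, 0.88]\n",
"})\n",
"# The converted frame has columns ['text', 'prediction'], where each prediction is a\n",
"# list of [label, score] pairs, the structure later passed to rg.DatasetForTextClassification.from_pandas\n",
"convert_df(example_result_df)"
]
},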
{
"cell_type": "code",
"execution_count": null,
"id": "4542692f",
"metadata": {},
"outputs": [],
"source": [
"def custom_f1(data: Dict[str, float], title: str):\n",
" from plotly.subplots import make_subplots\n",
" import plotly.colors\n",
" import random\n",
"\n",
" fig = make_subplots(\n",
" rows=2,\n",
" cols=1,\n",
" subplot_titles=[ \"Overall Model Score\", \"Model Score By Category\", ],\n",
" )\n",
"\n",
" x = ['precision', 'recall', 'f1']\n",
" macro_data = [v for k, v in data.items() if \"macro\" in k]\n",
" fig.add_bar(\n",
" x=x,\n",
" y=macro_data,\n",
" row=1,\n",
" col=1,\n",
" )\n",
" per_label = {\n",
" k: v\n",
" for k, v in data.items()\n",
" if all(key not in k for key in [\"macro\", \"micro\", \"support\"])\n",
" }\n",
"\n",
" num_labels = int(len(per_label.keys())/3)\n",
" fixed_colors = [str(color) for color in plotly.colors.qualitative.Plotly]\n",
" colors = random.sample(fixed_colors, num_labels)\n",
"\n",
" fig.add_bar(\n",
" x=[k for k, v in per_label.items()],\n",
" y=[v for k, v in per_label.items()],\n",
" row=2,\n",
" col=1,\n",
" marker_color=[colors[int(i/3)] for i in range(0, len(per_label.keys()))]\n",
" )\n",
" fig.update_layout(showlegend=False, title_text=title)\n",
"\n",
" return fig"
]
},
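{
"cell_type": "code",
"execution_count": null,
"id": "c91d7f3e",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: the kind of metrics dict `custom_f1` consumes, with three keys\n",
"# containing 'macro' for the overall bars and a precision/recall/f1 triple per label.\n",
"# The label names and numbers are invented; real values come from argilla's f1() metric\n",
"# in the evaluation callback further down.\n",
"example_metrics = {\n",
"    'FUNCTION_precision': 0.82, 'FUNCTION_recall': 0.75, 'FUNCTION_f1': 0.78,\n",
"    'NO FUNCTION_precision': 0.88, 'NO FUNCTION_recall': 0.90, 'NO FUNCTION_f1': 0.89,\n",
"    'macro_precision': 0.85, 'macro_recall': 0.83, 'macro_f1': 0.84\n",
"}\n",
"# custom_f1(example_metrics, 'Example Model Evaluation')"
]
},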
{
"cell_type": "code",
"execution_count": null,
"id": "07e3ff7b",
"metadata": {},
"outputs": [],
"source": [
"def get_null_class_df(sentences, result_df):\n",
" sents = result_df['Phrase'].tolist()\n",
" null_sents = [x for x in sentences if x not in sents]\n",
" topics = ['NO FUNCTION'] * len(null_sents)\n",
" scores = [0.90] * len(null_sents)\n",
" null_df = pd.DataFrame({'Phrase': null_sents, 'Topic': topics, 'Score': scores})\n",
" return null_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d5f8d3c",
"metadata": {},
"outputs": [],
"source": [
"#setfit func query and get predicted topic\n",
"\n",
"def get_sf_func_topic(sentences):\n",
" preds = list(sf_func_model(sentences))\n",
" return preds\n",
"def get_sf_func_topic_scores(sentences):\n",
" preds = sf_func_model.predict_proba(sentences)\n",
" preds = [max(list(x)) for x in preds]\n",
" return preds"
]
},
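{
"cell_type": "code",
"execution_count": null,
"id": "d2a8e5b4",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of how the two SetFit helpers above fit together. It assumes the\n",
"# `sf_func_model` SetFit classifier was loaded in an earlier cell; the sentences are\n",
"# illustrative only, so the calls are left commented out.\n",
"# example_sentences = ['He screams to get attention from his mother.',\n",
"#                      'The weather was sunny that day.']\n",
"# topic_inds = get_sf_func_topic(example_sentences)           # predicted class index per sentence\n",
"# topic_scores = get_sf_func_topic_scores(example_sentences)  # highest class probability per sentence\n",
"# list(zip(example_sentences, topic_inds, topic_scores))"
]
},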
{
"cell_type": "code",
"execution_count": null,
"id": "67bf2154",
"metadata": {},
"outputs": [],
"source": [
"# setfit func format output\n",
"ind_func_topic_dict = {\n",
" 0: 'NO FUNCTION',\n",
" 1: 'FUNCTION',\n",
" }\n",
"\n",
"highlight_threshold = 0.25\n",
"passing_score = 0.50\n",
"\n",
"def sf_func_color(df):\n",
" return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#CCFFCC')\n",
"\n",
"def sf_annotate_query(highlights, query, topics):\n",
" ents = []\n",
" query = query.strip() # remove newline characters from the query string\n",
" for h, t in zip(highlights, topics):\n",
" h = re.escape(h) # escape special characters in the highlights string\n",
" ent_dict = {}\n",
" for match in re.finditer(h, query):\n",
" ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": t}\n",
" break\n",
" if len(ent_dict.keys()) > 0:\n",
" ents.append(ent_dict)\n",
" return ents"
]
},
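{
"cell_type": "code",
"execution_count": null,
"id": "e6c3a7f9",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of the displaCy payload that `sf_annotate_query` produces. The query,\n",
"# highlight and colour below are illustrative; the real values are assembled inside the\n",
"# button callbacks further down, so the calls are left commented out.\n",
"# example_query = 'He hits out when asked to finish his work.'\n",
"# example_ents = sf_annotate_query(['hits out'], example_query, ['FUNCTION'])\n",
"# example_doc = [{'text': example_query, 'ents': example_ents, 'title': None}]\n",
"# displacy.render(example_doc, style='ent', manual=True,\n",
"#                 options={'ents': ['FUNCTION'], 'colors': {'FUNCTION': '#CCFFCC'}})"
]
},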
{
"cell_type": "code",
"execution_count": null,
"id": "316bd9e2",
"metadata": {},
"outputs": [],
"source": [
"#query and get predicted topic\n",
"\n",
"p_classes = {'avoid_activities': 0,\n",
" 'avoid_others': 1,\n",
" 'avoid_situations': 2,\n",
" 'avoid_stimuli': 3,\n",
" 'unknown': 4\n",
" }\n",
"def get_topic(sentences):\n",
" preds = []\n",
" for t in sentences:\n",
" sentence = Sentence(t)\n",
" tars.predict(sentence)\n",
" try:\n",
" pred = p_classes[sentence.tag]\n",
" except:\n",
" pred = 4\n",
" preds.append(pred)\n",
" return preds\n",
"def get_topic_scores(sentences):\n",
" preds = []\n",
" for t in sentences:\n",
" sentence = Sentence(t)\n",
" tars.predict(sentence)\n",
" try:\n",
" pred = sentence.score\n",
" except:\n",
" pred = 0.75\n",
" preds.append(pred)\n",
" return preds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1cadde44",
"metadata": {},
"outputs": [],
"source": [
"# format output\n",
"ind_topic_dict = {\n",
" 0: 'AVOID-ACTIVITIES',\n",
" 1: 'AVOID-OTHERS',\n",
" 2: 'AVOID-SITUATIONS',\n",
" 3: 'AVOID-STIMULI',\n",
" 4: 'UNKNOWN'\n",
" }\n",
"\n",
"topic_color_dict = {\n",
" 'AVOID-ACTIVITIES': '#FFE5E5', # lighter shade of red\n",
" 'AVOID-OTHERS': '#E5FFFF', # lighter shade of blue-green\n",
" 'AVOID-SITUATIONS': '#FFFFE5', # lighter shade of yellow\n",
" 'AVOID-STIMULI': '#FFCCE5' # lighter shade of pink\n",
"}\n",
"\n",
"gain_avoid_passing_score = 0.25\n",
"\n",
"def gain_avoid_color(df, color):\n",
" return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)\n",
"\n",
"def gain_avoid_annotate_query(highlights, query, topics):\n",
" ents = []\n",
" for h, t in zip(highlights, topics):\n",
" ent_dict = {}\n",
" for match in re.finditer(h, query):\n",
" ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": t}\n",
" break\n",
" if len(ent_dict.keys()) > 0:\n",
" ents.append(ent_dict)\n",
" return ents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51dbd744",
"metadata": {},
"outputs": [],
"source": [
"def path_to_image_html(path):\n",
" return ''\n",
"\n",
"final_passing = 0.0\n",
"def display_final_df(agg_df):\n",
" tags = []\n",
" crits = [\n",
" 'AVOID-ACTIVITIES',\n",
" 'AVOID-OTHERS',\n",
" 'AVOID-SITUATIONS',\n",
" 'AVOID-STIMULI'\n",
" ]\n",
" orig_crits = crits\n",
" crits = [x for x in crits if x in agg_df.index.tolist()]\n",
" bools = [agg_df.loc[crit, 'Final_Score'] > final_passing for crit in crits]\n",
" paths = ['./tick_green.png' if x else './cross_red.png' for x in bools]\n",
" df = pd.DataFrame({'Topic': crits, 'USED': paths})\n",
" rem_crits = [x for x in orig_crits if x not in crits]\n",
" if len(rem_crits) > 0:\n",
" df2 = pd.DataFrame({'Topic': rem_crits, 'USED': ['./cross_red.png'] * len(rem_crits)})\n",
" df = pd.concat([df, df2])\n",
" df = df.set_index('Topic')\n",
" pd.set_option('display.max_colwidth', None)\n",
" display(HTML('' + df.to_html(classes=[\"align-center\"], index=True, escape=False ,formatters=dict(USED=path_to_image_html)) + '
'))"
]
},
{
"cell_type": "markdown",
"id": "2c6e9fe7",
"metadata": {},
"source": [
"### Practitioner Section\n",
"#### Enter a summary statement outlining the functional hypothesis"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76dd8cab",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"#demo with Voila\n",
"\n",
"func_label = widgets.Label(value='Please type your answer:')\n",
"func_text_input = widgets.Textarea(\n",
" value='',\n",
" placeholder='Type your answer',\n",
" description='',\n",
" disabled=False,\n",
" layout={'height': '300px', 'width': '90%'}\n",
")\n",
"\n",
"func_nlp_btn = widgets.Button(\n",
" description='Score Functions',\n",
" disabled=False,\n",
" button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
" tooltip='Score Functions',\n",
" icon='check',\n",
" layout={'height': '70px', 'width': '250px'}\n",
")\n",
"gain_avoid_nlp_btn = widgets.Button(\n",
" description='Detect Gain / Avoid',\n",
" disabled=False,\n",
" button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
" tooltip='Detect Gain / Avoid',\n",
" icon='check',\n",
" layout={'height': '70px', 'width': '250px'}\n",
")\n",
"bhvr_agr_btn = widgets.Button(\n",
" description='Validate Data',\n",
" disabled=False,\n",
" button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
" tooltip='Validate Data',\n",
" icon='check',\n",
" layout={'height': '70px', 'width': '250px'}\n",
")\n",
"bhvr_eval_btn = widgets.Button(\n",
" description='Evaluate Model',\n",
" disabled=False,\n",
" button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
" tooltip='Evaluate Model',\n",
" icon='check',\n",
" layout={'height': '70px', 'width': '250px'}\n",
")\n",
"btn_box = widgets.HBox([bhvr_agr_btn, bhvr_eval_btn], \n",
" layout={'width': '100%', 'height': '160%'})\n",
"func_btn_box = widgets.HBox([func_nlp_btn, gain_avoid_nlp_btn], \n",
" layout={'width': '100%', 'height': '160%'})\n",
"func_outt = widgets.Output()\n",
"func_outt.layout.height = '100%'\n",
"func_outt.layout.width = '100%'\n",
"func_box = widgets.VBox([func_text_input, func_btn_box, btn_box, func_outt], \n",
" layout={'width': '100%', 'height': '160%'})\n",
"dataset_rg_name = 'pbsp-page3-func-argilla-ds'\n",
"agrilla_df = None\n",
"annotated = False\n",
"sub_2_result_dfs = []\n",
"def on_func_button_next(b):\n",
" global fh_onto_lst, cl_fh_onto_lst, emb_fh_onto_lst, agrilla_df\n",
" with func_outt:\n",
" clear_output()\n",
" fh_onto_lst = fh_onto_text_input.value.split(\"\\n\")\n",
" cl_fh_onto_lst = preprocess(fh_onto_lst)\n",
" orig_cl_dict = {x:y for x,y in zip(cl_fh_onto_lst, fh_onto_lst)}\n",
" emb_fh_onto_lst = sentence_embeddings(cl_fh_onto_lst)\n",
" add_to_collection()\n",
" query = func_text_input.value\n",
" vbs = get_verb_phrases(query)\n",
" cl_vbs = preprocess(vbs)\n",
" emb_vbs = sentence_embeddings(cl_vbs)\n",
" vb_ind = -1\n",
" highlights = []\n",
" highlight_scores = []\n",
" result_dfs = []\n",
" for query_vector in emb_vbs:\n",
" vb_ind += 1\n",
" hist = search_collection('functional_hypothesis', query_vector)\n",
" hist_dict = [dict(x) for x in hist]\n",
" scores = [x['score'] for x in hist_dict]\n",
" payloads = [orig_cl_dict[x['payload']['phrase']] for x in hist_dict]\n",
" result_df = pd.DataFrame({'Score': scores, 'Glossary': payloads})\n",
" result_df = result_df[result_df['Score'] >= semantic_passing_score]\n",
" if len(result_df) > 0:\n",
" highlights.append(vbs[vb_ind])\n",
" highlight_scores.append(result_df.Score.max())\n",
" result_df['Phrase'] = [vbs[vb_ind]] * len(result_df)\n",
" result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
" result_dfs.append(result_df)\n",
" else:\n",
" continue\n",
" ents = []\n",
" colors = {}\n",
" if len(highlights) > 0:\n",
" ents = annotate_query(highlights, query)\n",
" for ent in ents:\n",
" colors[ent['label']] = '#ADD8E6'\n",
" \n",
" #setfit function\n",
" sentences = extract_sentences(query)\n",
" cl_sentences = preprocess(sentences)\n",
" topic_inds = get_sf_func_topic(cl_sentences)\n",
" topics = [ind_func_topic_dict[i] for i in topic_inds]\n",
" scores = get_sf_func_topic_scores(cl_sentences)\n",
" sf_func_result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n",
" sf_func_sub_result_df = sf_func_result_df[sf_func_result_df['Topic'] == 'FUNCTION']\n",
" sub_2_result_df = sf_func_result_df[sf_func_result_df['Topic'] == 'NO FUNCTION']\n",
" sub_2_result_df = pd.concat([sub_2_result_df, sf_func_sub_result_df]).reset_index(drop=True)\n",
" sub_2_result_dfs.append(sub_2_result_df)\n",
" sf_func_highlights = []\n",
" sf_func_ents = []\n",
" if len(sf_func_sub_result_df) > 0:\n",
" sf_func_highlights = sf_func_sub_result_df['Phrase'].tolist()\n",
" sf_func_highlight_topics = sf_func_sub_result_df['Topic'].tolist()\n",
" sf_func_highlight_scores = sf_func_sub_result_df['Score'].tolist() \n",
" sf_func_ents = sf_annotate_query(sf_func_highlights, query, sf_func_highlight_topics)\n",
" for ent, hs in zip(sf_func_ents, sf_func_highlight_scores):\n",
" if hs >= passing_score:\n",
" colors[ent['label']] = '#CCFFCC'\n",
" else:\n",
" colors[ent['label']] = '#FFCC66'\n",
" options = {\"ents\": list(colors), \"colors\": colors}\n",
" if len(sf_func_ents) > 0:\n",
" ents = ents + sf_func_ents\n",
" \n",
" ex = [{\"text\": query,\n",
" \"ents\": ents,\n",
" \"title\": None}]\n",
" if len(ents) > 0:\n",
" title = \"Answer Highlights\"\n",
" display(HTML(f'{title}
'))\n",
" html = displacy.render(ex, style=\"ent\", manual=True, options=options)\n",
" display(HTML(html))\n",
" if len(result_dfs) > 0:\n",
" title = \"Similar to Glossary\"\n",
" display(HTML(f'{title}
'))\n",
" result_df = pd.concat(result_dfs).reset_index(drop = True)\n",
" result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
" sub_2_result_df = result_df.copy()\n",
" sub_2_result_df['Topic'] = ['FUNCTION'] * len(result_df)\n",
" sub_2_result_df = sub_2_result_df[['Phrase', 'Topic', 'Score']].drop_duplicates().reset_index(drop=True)\n",
" null_df = get_null_class_df(vbs, sub_2_result_df)\n",
" if len(null_df) > 0:\n",
" sub_2_result_df = pd.concat([sub_2_result_df, null_df]).reset_index(drop=True)\n",
" sub_2_result_dfs.append(sub_2_result_df)\n",
" agg_df = result_df.groupby(result_df.Phrase).max()\n",
" agg_df['Phrase'] = agg_df.index\n",
" agg_df = agg_df.reset_index(drop=True)\n",
" agg_df = agg_df.drop(columns=['Glossary'])\n",
" result_df = pd.merge(result_df, agg_df, 'inner', ['Phrase', 'Score'])\n",
" result_df = result_df[['Phrase', 'Glossary', 'Score']]\n",
" result_df = result_df.set_index('Phrase')\n",
" display(color(result_df))\n",
" if len(sf_func_sub_result_df) > 0:\n",
" title = \"Detected Functions\"\n",
" display(HTML(f'{title}
'))\n",
" result_df = sf_func_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
" result_df = result_df.set_index('Phrase')\n",
" display(sf_func_color(result_df))\n",
" if len(sub_2_result_dfs) > 0:\n",
" sub_2_result_df = pd.concat(sub_2_result_dfs).reset_index(drop=True)\n",
" agrilla_df = sub_2_result_df.copy()\n",
"\n",
"def on_gain_avoid_button_next(b):\n",
" global agrilla_df\n",
" with func_outt:\n",
" clear_output()\n",
" query = func_text_input.value\n",
" sentences = extract_sentences(query)\n",
" cl_sentences = preprocess(sentences)\n",
" topic_inds = get_topic(cl_sentences)\n",
" topics = [ind_topic_dict[i] for i in topic_inds]\n",
" scores = get_topic_scores(cl_sentences)\n",
" result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n",
" sub_result_df = result_df[(result_df['Score'] >= gain_avoid_passing_score) & (result_df['Topic'] != 'UNKNOWN')]\n",
" sub_2_result_df = result_df[result_df['Topic'] == 'UNKNOWN']\n",
" highlights = []\n",
" if len(sub_result_df) > 0:\n",
" highlights = sub_result_df['Phrase'].tolist()\n",
" highlight_topics = sub_result_df['Topic'].tolist() \n",
" ents = gain_avoid_annotate_query(highlights, query, highlight_topics)\n",
" colors = {}\n",
" for ent, ht in zip(ents, highlight_topics):\n",
" colors[ent['label']] = topic_color_dict[ht]\n",
"\n",
" ex = [{\"text\": query,\n",
" \"ents\": ents,\n",
" \"title\": None}]\n",
" title = \"Gaining & Avoidance Highlights\"\n",
" display(HTML(f'{title}
'))\n",
" html = displacy.render(ex, style=\"ent\", manual=True, jupyter=True, options={'colors': colors})\n",
" display(HTML(html))\n",
" title = \"Gaining & Avoidance Classifications\"\n",
" display(HTML(f'{title}
'))\n",
" for top in topic_color_dict.keys():\n",
" top_result_df = sub_result_df[sub_result_df['Topic'] == top]\n",
" if len(top_result_df) > 0:\n",
" top_result_df = top_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
" top_result_df = top_result_df.set_index('Phrase')\n",
" top_result_df = top_result_df[['Score']]\n",
" display(HTML(\n",
" f'{top}
'))\n",
" display(gain_avoid_color(top_result_df, topic_color_dict[top]))\n",
" \n",
" agg_df = sub_result_df.groupby('Topic')['Score'].sum()\n",
" agg_df = agg_df.to_frame()\n",
" agg_df.index.name = 'Topic'\n",
" agg_df.columns = ['Total Score']\n",
" agg_df = agg_df.assign(\n",
" Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00\n",
" )\n",
" agg_df = agg_df.sort_values(by='Final_Score', ascending=False)\n",
" title = \"Gaining & Avoidance Coverage\"\n",
" display(HTML(f'{title}
'))\n",
" agg_df['Topic'] = agg_df.index\n",
" rem_topics= [x for x in list(topic_color_dict.keys()) if not x in agg_df.Topic.tolist()]\n",
" if len(rem_topics) > 0:\n",
" rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})\n",
" agg_df = pd.concat([agg_df, rem_agg_df])\n",
" labels = agg_df['Final_Score'].round(1).astype('str') + '%'\n",
" ax = agg_df.plot.bar(x='Topic', y='Final_Score', rot=0, figsize=(20, 5), align='center')\n",
" for container in ax.containers:\n",
" ax.bar_label(container, labels=labels)\n",
" ax.yaxis.set_major_formatter(mtick.PercentFormatter())\n",
" ax.legend([\"Final Score (%)\"])\n",
" ax.set_xlabel('')\n",
" plt.show()\n",
" title = \"Gaining & Avoidance Scores\"\n",
" display(HTML(f'{title}
'))\n",
" display_final_df(agg_df)\n",
" if len(sub_2_result_df) > 0:\n",
" sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)\n",
" agrilla_df = sub_result_df.copy()\n",
" else:\n",
" print(query)\n",
"\n",
"def on_agr_button_next(b):\n",
" global agrilla_df, annotated\n",
" with func_outt:\n",
" clear_output()\n",
" if agrilla_df is not None:\n",
" # convert the dataframe to the structure accepted by argilla\n",
" converted_df = convert_df(agrilla_df)\n",
" # convert pandas dataframe to DatasetForTextClassification\n",
" dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n",
" # delete the old DatasetForTextClassification from the Argilla web app if exists\n",
" rg.delete(dataset_rg_name, workspace=\"admin\")\n",
" # load the new DatasetForTextClassification into the Argilla web app\n",
" rg.log(dataset_rg, name=dataset_rg_name, workspace=\"admin\")\n",
" # Make sure all classes are present for annotation\n",
" if 'FUNCTION' in agrilla_df['Topic'].unique().tolist():\n",
" rg_settings = rg.TextClassificationSettings(label_schema=['FUNCTION', \n",
" 'GLOSSARY', \n",
" 'NONE'\n",
" ])\n",
" else:\n",
" rg_settings = rg.TextClassificationSettings(label_schema=['AVOID-ACTIVITIES',\n",
" 'AVOID-OTHERS',\n",
" 'AVOID-SITUATIONS',\n",
" 'AVOID-STIMULI',\n",
" 'UNKNOWN'\n",
" ])\n",
" rg.configure_dataset(name=dataset_rg_name, workspace=\"admin\", settings=rg_settings)\n",
" annotated = True\n",
" else:\n",
" display(Markdown(\"Please score the answer first!
\"))\n",
" \n",
"def on_eval_button_next(b):\n",
" global annotated\n",
" with func_outt:\n",
" clear_output()\n",
" if annotated:\n",
" data = dict(f1(dataset_rg_name))['data']\n",
" display(custom_f1(data, \"Model Evaluation Results\"))\n",
" else:\n",
" display(Markdown(\"Please score the answer and validate the data first!
\"))\n",
"\n",
"func_nlp_btn.on_click(on_func_button_next)\n",
"gain_avoid_nlp_btn.on_click(on_gain_avoid_button_next)\n",
"bhvr_agr_btn.on_click(on_agr_button_next)\n",
"bhvr_eval_btn.on_click(on_eval_button_next)\n",
"\n",
"display(func_label, func_box)"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3.9 (Argilla)",
"language": "python",
"name": "argilla"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "258.097px"
},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}