import spacy |
import streamlit as st |
import pandas as pd |
from PyPDF2 import PdfReader |
from io import StringIO |
import json |
import warnings |
import os |
import ast |
@st.cache(show_spinner=False, allow_output_mutation=True, suppress_st_warning=True)
def load_models(model_names: list, args: dict, model_names_dir: list) -> dict:
    """
    Load every requested spaCy model and return them keyed by name.

    A name that appears in ``model_names_dir`` is treated as a fine-tuned model and
    loaded from ``os.path.join(args['model_dir'], model_name)``; any other name is
    passed to ``spacy.load`` unchanged.

    Parameters:
        model_names: list of model names for inference.
        args: dict of configuration parameters; needs a 'model_dir' entry whenever a
            fine-tuned model is requested.
        model_names_dir: list of model names that are fine-tuned models located in
            args['model_dir'].
    Returns:
        model_dict: dict mapping each successfully loaded model name to its model.
        A model that fails to load is skipped with a warning; it is never mapped to
        a stale or undefined object.
    Raises:
        AssertionError: if ``model_names`` is None or empty.
    """
    # Rejects both None and an empty list (truthiness check).
    assert model_names, "No models avaliable"
    model_dict = {}
    for model_name in model_names:
        print(model_name)
        if model_name in model_names_dir:
            try:
                model_path = os.path.join(args['model_dir'], model_name)
                model = spacy.load(model_path)
            except Exception:
                # Best-effort: report and move on to the next model instead of
                # storing whatever `model` held from a previous iteration.
                warnings.warn(f"Path to {model_name} not found")
                continue
        else:
            try:
                model = spacy.load(model_name)
            except Exception:
                warnings.warn(f'Model: {model_name} not found')
                continue
        model_dict[model_name] = model
        print('Model loaded')
    return model_dict
def process_text(doc: "spacy.tokens.Doc", selected_entities: list, colors: dict) -> list:
    """
    Group the tokens of a spaCy doc by entity so st-annotations can render them.

    Consecutive tokens belonging to the same entity (same label, continuing "I" IOB
    tag) are joined into one annotated tuple; every other token is emitted as a
    plain string padded with spaces.

    Example: "Hi John, i am sick with cough and flu"
    Entities: person, disease
    Output: [" Hi ", ("John ", "person", blue, "#464646"), " , ", ...,
             ("cough ", "disease", red, "#464646"), " and ", ("flu ", "disease", red, "#464646")]

    Parameters:
        doc: spaCy document (iterable of tokens exposing text, ent_type_, ent_iob_).
        selected_entities: list of entity labels to highlight; other labels are
            rendered as plain text.
        colors: dict mapping each selected entity label to its display color.
    Returns:
        tokens: list of plain strings and (span_text, label, color, '#464646') tuples.
    """
    tokens = []
    span = ''
    span_ent = None
    for token in doc:
        ent = token.ent_type_ if token.ent_type_ in selected_entities else None
        # Close the open entity span when the label changes or when spaCy marks
        # the start of a new entity ("B"), so neighbouring entities stay separate.
        if span and (ent != span_ent or token.ent_iob_ == 'B'):
            tokens.append((span, span_ent, colors[span_ent], '#464646'))
            span = ''
        if ent is None:
            tokens.append(" " + token.text + " ")
        else:
            span += token.text + " "
        span_ent = ent
    # Flush an entity that runs to the end of the document.
    if span:
        tokens.append((span, span_ent, colors[span_ent], '#464646'))
    return tokens
def process_text_compare(infer_input: dict, selected_entities: list, colors: dict) -> list:
    """
    Split annotated evaluation text into chunks so st-annotations can render them.

    Used when comparing predictions against labels. Each selected entity becomes an
    annotated tuple; the text between entities is emitted as plain strings padded
    with spaces.

    Example: "Hi John, i am sick with cough and flu"
    Entities: person, disease
    Output: [" Hi  ", ("John", "person", blue, "#464646"), " , i am sick with  ",
             ("cough", "disease", red, "#464646"), "  and  ", ("flu", "disease", red, "#464646")]

    Parameters:
        infer_input: dict with 'text' (str) and 'entities', a sequence of
            (start, end, label) character offsets. Offsets are treated as half-open
            [start, end), the convention of spaCy's start_char/end_char — confirm
            against the data producer.
        selected_entities: list of entity labels to highlight; entities with other
            labels are left as plain text.
        colors: dict mapping each selected entity label to its display color.
    Returns:
        tokens: list of plain strings and (span_text, label, color, '#464646') tuples.
    """
    text = infer_input['text']
    # Walk entities left to right so the gaps between them are computed correctly
    # even when the input list is not ordered.
    spans = sorted(
        (entity for entity in infer_input['entities'] if entity[2] in selected_entities),
        key=lambda entity: entity[0],
    )
    tokens = []
    cursor = 0  # index of the first character not yet emitted
    for start, end, label in spans:
        if start > cursor:
            tokens.append(" " + text[cursor:start] + " ")
        tokens.append((text[start:end], label, colors[label], '#464646'))
        # max() keeps the cursor from moving backwards on nested/overlapping spans.
        cursor = max(cursor, end)
    if cursor < len(text):
        tokens.append(" " + text[cursor:] + " ")
    return tokens
def _pdf_to_text(uploaded_file) -> str:
    """Extract the text of every page of a PDF file-like object, pages separated by newlines."""
    reader = PdfReader(uploaded_file)
    # `or ""` guards against a page yielding no text (None); the newline keeps the
    # last word of one page from fusing with the first word of the next.
    return "\n".join(page.extract_text() or "" for page in reader.pages)


def process_files(uploaded_file, text_input):
    """
    Turn an uploaded file into model input, replacing ``text_input`` when a file is given.

    Supported formats are chosen by (case-insensitive) file extension:
      - .csv  -> list of row dicts; the 'entities' column is parsed with ast.literal_eval
      - .json -> the parsed JSON object
      - .pdf  -> the extracted text of all pages
      - other -> the file decoded as UTF-8 text; if decoding fails the file is
                 tried as a PDF
    Parameters:
        uploaded_file: Streamlit UploadedFile (a BytesIO subclass with a ``name``), or None.
        text_input: str / dict / list returned unchanged when no file was uploaded.
    Return:
        text_input: list / dict / str
    """
    if uploaded_file is None:
        return text_input
    extension = os.path.splitext(uploaded_file.name)[1].lower()
    if extension == '.csv':
        records = pd.read_csv(uploaded_file, converters={'entities': ast.literal_eval})
        return records.to_dict('records')
    if extension == '.json':
        return json.load(uploaded_file)
    if extension == '.pdf':
        return _pdf_to_text(uploaded_file)
    try:
        # Return the text as-is; re-appending "\n" to readlines() output doubled newlines.
        return uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # Binary content with an unknown extension: fall back to PDF extraction.
        return _pdf_to_text(uploaded_file)