# ner_pg/src/app_utils.py
import ast
import json
import os
import warnings
from io import StringIO

import pandas as pd
import spacy
import streamlit as st
from PyPDF2 import PdfReader

# st.cache is deprecated in recent Streamlit releases; prefer st.cache_resource
# (below) if the installed version supports it.
@st.cache(show_spinner=False, allow_output_mutation=True, suppress_st_warning=True)
# @st.cache_resource
def load_models(model_names: list, args: dict, model_names_dir: list) -> dict:
"""
Check if model name refers to fine tuned models that are located in the model_dir or
default models native to spacy. Load them according to required methods
Parameters:
model_names: list of model names for inference
args: dict, configuration parameters
model_names_dir: list of model that are from the model_names_dir which are fine tuned models
Returns:
model_dict: A dictionary of keys representing the model names and values containing the model.
"""
    assert model_names is not None and len(model_names) != 0, "No models available"
    model_dict = {}
    for model_name in model_names:
        # fine-tuned models are loaded from the model directory
        if model_name in model_names_dir:
            try:
                model_path = os.path.join(args['model_dir'], model_name)
                model = spacy.load(model_path)
            except OSError:
                warnings.warn(f"Path to {model_name} not found")
                continue
        else:
            # default models are loaded from the installed spaCy packages
            try:
                model = spacy.load(model_name)
            except OSError:
                warnings.warn(f"Model: {model_name} not found")
                continue
        model_dict[model_name] = model
        print(f'{model_name} loaded')
    return model_dict
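
# Minimal usage sketch (the config values and model names below are
# placeholder assumptions, not part of this module):
#
#   args = {'model_dir': 'models/'}
#   model_dict = load_models(
#       model_names=['en_core_web_sm', 'my_finetuned_ner'],
#       args=args,
#       model_names_dir=['my_finetuned_ner'],
#   )
#   doc = model_dict['en_core_web_sm']('Hi John, i am sick with cough and flu')
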
def process_text(doc: spacy.tokens.Doc, selected_entities: list, colors: dict) -> list:
    """
    Process the tokens of the Doc returned by a spaCy model so that consecutive
    tokens are grouped together by their corresponding entities. This allows
    st-annotated-text to render the tokens for visualization.
    Example: "Hi John, i am sick with cough and flu"
    Entities: person, disease
    Output: [(Hi)(John, 'person', blue)(i am sick with)(cough, 'disease', red)(and)(flu, 'disease', red)]
    Parameters:
        doc : spaCy Doc
        selected_entities : list of entity labels to highlight
        colors : dict mapping entity labels to colors
    Returns:
        tokens: list of plain strings and annotation tuples
    """
    tokens = []
    span = ''
    p_ent = None
    last = len(doc)
    for no, token in enumerate(doc):
        add_span = token.ent_type_ in selected_entities
        if add_span:
            if p_ent is not None and token.ent_type_ != p_ent:
                # adjacent entities of different types: flush the previous span
                tokens.append((span, p_ent, colors[p_ent], '#464646'))
                span = ''
            span += token.text + " "
            p_ent = token.ent_type_
            if no + 1 == last:
                # last token in the doc: flush the final entity span
                tokens.append((span, p_ent, colors[p_ent], '#464646'))
        else:
            if len(span) > 1:
                # an entity span has just ended: flush it before the plain token
                tokens.append((span, p_ent, colors[p_ent], '#464646'))
                span = ''
                p_ent = None
            tokens.append(" " + token.text + " ")
    return tokens
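
# Illustrative usage sketch (the model name and the color mapping are
# assumptions for demonstration, not part of this module):
#
#   nlp = spacy.load('en_core_web_sm')
#   doc = nlp('Apple is looking at buying a U.K. startup')
#   tokens = process_text(doc, ['ORG', 'GPE'], {'ORG': '#8ef', 'GPE': '#faa'})
#   # -> [('Apple ', 'ORG', '#8ef', '#464646'), ' is ', ' looking ', ...]
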
def process_text_compare(infer_input: dict, selected_entities: list, colors: dict) -> list:
    """
    Used when the user wants to compare the text annotations between the
    predictions and the labels. Processes the evaluation data so that character
    spans are grouped together by their corresponding entities. This allows
    st-annotated-text to render the tokens for visualization.
    Example: "Hi John, i am sick with cough and flu"
    Entities: person, disease
    Output: [(Hi)(John, 'person', blue)(i am sick with)(cough, 'disease', red)(and)(flu, 'disease', red)]
    Parameters:
        infer_input : dict with a 'text' string and an 'entities' list of (start, end, label) tuples
        selected_entities : list of entity labels to highlight
        colors : dict mapping entity labels to colors
    Returns:
        tokens: list of plain strings and annotation tuples
    """
    tokens = []
    start_ = 0
    end_ = len(infer_input['text'])
    # character offsets are assumed to be end-exclusive, following spaCy's
    # (start, end, label) annotation convention
    for start, end, entity in infer_input['entities']:
        if entity in selected_entities:
            # the span of characters covered by the detected entity
            span = infer_input['text'][start:end]
            # the plain text between the previous entity and this one
            if start_ != start:
                b4_span = infer_input['text'][start_:start]
                tokens.append(" " + b4_span + " ")
            tokens.append((span, entity, colors[entity], '#464646'))
            start_ = end
    # any trailing plain text after the last entity
    if start_ < end_:
        tokens.append(" " + infer_input['text'][start_:end_] + " ")
    return tokens
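
# Minimal input sketch (the offsets below follow the end-exclusive convention
# assumed above and are illustrative, not real evaluation data):
#
#   infer_input = {
#       'text': 'Hi John, i am sick with cough and flu',
#       'entities': [(3, 7, 'person'), (24, 29, 'disease'), (34, 37, 'disease')],
#   }
#   tokens = process_text_compare(infer_input, ['person', 'disease'],
#                                 {'person': '#8ef', 'disease': '#faa'})
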
def process_files(uploaded_file, text_input):
    """
    The app allows uploading files of multiple types, at present
    json, csv, pdf and txt.
    This function detects what kind of file has been uploaded and processes
    it accordingly.
    If a file has been uploaded, its contents replace the existing text_input.
    Parameters:
        uploaded_file: an UploadedFile; the class is a subclass of BytesIO and is therefore "file-like".
        text_input: str / dict / list
    Return:
        text_input: list / dict / str
    """
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.csv'):
            # literal_eval turns the string representation of a list back into a list
            text_input = pd.read_csv(uploaded_file, converters={'entities': ast.literal_eval})
            text_input = text_input.to_dict('records')
        elif uploaded_file.name.endswith('.json'):
            text_input = json.load(uploaded_file)
        else:
            try:
                # plain text file: decode the bytes directly
                text_input = ""
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
                for line in stringio.readlines():
                    text_input += line + "\n"
            except UnicodeDecodeError:
                # not valid utf-8, so assume pdf and extract text page by page
                text_input = []
                reader = PdfReader(uploaded_file)
                count = len(reader.pages)
                for i in range(count):
                    page = reader.pages[i]
                    text_input.append(page.extract_text())
                text_input = ''.join(text_input)
    return text_input