# demo_icd10 / app.py
# lyangas
# change interface of gradio
# 860fd6c
print('INFO: import modules')
import json
import pickle
import traceback

import gradio as gr

from required_classes import *
print('INFO: loading model')
try:
    # NOTE(review): pickle can execute arbitrary code while loading. This is
    # presumably acceptable only because the file is a local artifact bundled
    # with the app -- confirm, and never point this at user-supplied data.
    with open('model_finetuned_clear.pkl', 'rb') as pkl_file:
        model = pickle.load(pkl_file)
    # The classify helpers embed a single text per call ([text]); presumably
    # this is the embedding batch size -- TODO confirm in required_classes.
    model.batch_size = 1
except Exception as exc:
    # Best effort: the app keeps starting even if loading fails. `model` may
    # then be undefined, and the UI callbacks will report the resulting error.
    print(f"ERROR: loading models failed with: {str(exc)}")
else:
    print('INFO: model loaded')
def top_n_predictions(classifier, embed, top_n):
    """Return the ``top_n`` most probable classes for a single embedded sample.

    Args:
        classifier: fitted estimator exposing ``predict_proba`` and
            ``classes_`` (the scikit-learn naming used by this app's model).
        embed: input accepted by ``classifier.predict_proba``; only the first
            row of the resulting probability matrix is used.
        top_n: number of candidates to return. Values ``<= 0`` yield an empty
            dict. This is guarded explicitly because a slice such as ``[-0:]``
            would otherwise select *every* class.

    Returns:
        dict mapping class label -> probability (Python ``float``), ordered
        from most to least probable.
    """
    if top_n <= 0:
        return {}
    probs = classifier.predict_proba(embed)[0]
    # argsort is ascending; reverse it so the most probable class comes first,
    # then keep the first top_n indices.
    best_n = np.argsort(probs)[::-1][:top_n]
    return {classifier.classes_[i]: float(probs[i]) for i in best_n}


def classify_code(text, top_n):
    """Return ``{code: probability}`` for the ``top_n`` most probable codes of ``text``.

    The text is embedded with the loaded model's ``_texts2vecs`` and scored
    by ``model.classifier_code``.
    """
    embed = model._texts2vecs([text])
    return top_n_predictions(model.classifier_code, embed, top_n)


def classify_group(text, top_n):
    """Return ``{group: probability}`` for the ``top_n`` most probable groups of ``text``.

    The text is embedded with the loaded model's ``_texts2vecs`` and scored
    by ``model.classifier_group``.
    """
    embed = model._texts2vecs([text])
    return top_n_predictions(model.classifier_group, embed, top_n)
def classify(text, top_n):
    """Classify ``text`` into codes and code groups in one call.

    Args:
        text: free text to classify.
        top_n: number of candidates per classifier. It is converted with
            ``int()``, so fractional values are truncated. It must be >= 1.

    Returns:
        ``(codes, groups)``, a tuple of ``{label: probability}`` dicts. On any
        failure both elements are the same ``"Error: ..."`` string, so callers
        always receive a 2-tuple.

    NOTE(review): in the visible code the Gradio interface uses ``predict``,
    not this function -- confirm whether it is still needed.
    """
    try:
        top_n = int(top_n)
        if top_n < 1:
            # A non-positive TOP-N is meaningless; report it clearly.
            raise ValueError(f"TOP-N must be a positive integer, got {top_n}")
        return classify_code(text, top_n), classify_group(text, top_n)
    except Exception as e:
        # Boundary handler: turn the failure into a message, but keep the
        # traceback in the server log instead of silently discarding it.
        traceback.print_exc()
        error_msg = f"Error: {str(e)}"
        return error_msg, error_msg
print('INFO: starting gradio interface')

# Output widgets. They are passed to gr.Interface as `outputs` and are also
# the keys of the dict that `predict` returns.
box_class, box_group = (
    gr.Label(label="Result class"),
    gr.Label(label="Result group"),
)
def predict(text, top_n):
    """Gradio callback: classify ``text`` into codes and code groups.

    Args:
        text: contents of the "Input text" textbox.
        top_n: value of the "TOP-N candidates" number field. It is converted
            with ``int()``, so fractional values are truncated. It must be >= 1.

    Returns:
        dict keyed by the output components ``box_class`` and ``box_group``.
        On success each value is a ``{label: probability}`` dict for
        ``gr.Label``. On any failure both widgets show the same
        ``"Error: ..."`` text.
    """
    try:
        top_n = int(top_n)
        if top_n < 1:
            # A non-positive TOP-N is meaningless; show a clear message
            # instead of an arbitrary list of classes.
            raise ValueError(f"TOP-N must be a positive integer, got {top_n}")
        predicted_codes = classify_code(text, top_n)
        predicted_groups = classify_group(text, top_n)
        return {box_class: predicted_codes, box_group: predicted_groups}
    except Exception as e:
        # UI boundary: report the problem in the widgets rather than failing
        # the request, but keep the traceback in the server log.
        traceback.print_exc()
        error_msg = f"Error: {str(e)}"
        return {box_class: error_msg, box_group: error_msg}
# Sample text used to pre-fill the input box. The file is read with an
# explicit UTF-8 encoding (the JSON standard encoding) rather than the
# platform default, and it is closed deterministically.
with open('default_input.json', encoding='utf-8') as json_file:
    default_input_text = json.load(json_file)['input_text']

iface = gr.Interface(
    title="ICD10-codes classification",
    description="",
    fn=predict,
    inputs=[
        gr.Textbox(label="Input text", value=default_input_text),
        gr.Number(label="TOP-N candidates", value=3),
    ],
    outputs=[box_class, box_group],
)
# `.queue()` is Gradio's documented way to enable request queuing; the
# `enable_queue` constructor argument is deprecated.
iface.queue()

if __name__ == "__main__":
    iface.launch()