# Spaces:
# Build error
# Build error
print('INFO: import modules')
import json
import pickle

import gradio as gr
# np is used by the classifier helpers below; import it explicitly instead of
# relying on it leaking in through the star import.
import numpy as np

# NOTE(review): star import; presumably provides the class definitions needed to
# unpickle the model file (and possibly other helpers) — confirm and replace
# with explicit names.
from required_classes import *
print('INFO: loading model')
try:
    # NOTE(review): pickle.load can execute arbitrary code from the file; this is
    # only acceptable if the .pkl is a trusted artifact shipped with the app.
    with open('model_finetuned_clear.pkl', 'rb') as f:
        model = pickle.load(f)
        # Inference here handles one text per request.
        model.batch_size = 1
except Exception as exc:
    # Best-effort: report and keep going so the UI still starts; requests will
    # then fail inside the prediction callbacks and surface the error there.
    print(f"ERROR: loading models failed with: {str(exc)}")
else:
    print('INFO: model loaded')
def classify_code(text, top_n):
    """Return the ``top_n`` most probable labels from ``model.classifier_code``.

    Args:
        text: Input text; embedded with ``model._texts2vecs`` before scoring.
        top_n: Number of candidates to return. Values <= 0 yield an empty dict;
            values larger than the number of classes return every class.

    Returns:
        dict mapping class label -> probability (plain Python float), inserted
        from most to least probable.
    """
    embed = model._texts2vecs([text])
    # predict_proba is indexed as (sample, class); take the single sample's row.
    probs = model.classifier_code.predict_proba(embed)[0]
    # Descending order, then take the first top_n. Clamping at 0 fixes the old
    # ``[-top_n:]`` slice, which returned *all* classes when top_n == 0.
    best_n = np.argsort(probs)[::-1][:max(top_n, 0)]
    classes = model.classifier_code.classes_
    return {classes[i]: float(probs[i]) for i in best_n}
def classify_group(text, top_n):
    """Return the ``top_n`` most probable labels from ``model.classifier_group``.

    Args:
        text: Input text; embedded with ``model._texts2vecs`` before scoring.
        top_n: Number of candidates to return. Values <= 0 yield an empty dict;
            values larger than the number of classes return every class.

    Returns:
        dict mapping class label -> probability (plain Python float), inserted
        from most to least probable.
    """
    embed = model._texts2vecs([text])
    # predict_proba is indexed as (sample, class); take the single sample's row.
    probs = model.classifier_group.predict_proba(embed)[0]
    # Descending order, then take the first top_n. Clamping at 0 fixes the old
    # ``[-top_n:]`` slice, which returned *all* classes when top_n == 0.
    best_n = np.argsort(probs)[::-1][:max(top_n, 0)]
    classes = model.classifier_group.classes_
    return {classes[i]: float(probs[i]) for i in best_n}
def classify(text, top_n):
    """Run both classifiers on ``text`` and return ``(code_preds, group_preds)``.

    ``top_n`` is coerced with ``int()``. On any failure, both tuple slots hold
    the same ``"Error: ..."`` message string instead of a prediction dict.
    """
    try:
        n_candidates = int(top_n)
        code_preds = classify_code(text, n_candidates)
        group_preds = classify_group(text, n_candidates)
    except Exception as exc:
        message = f"Error: {str(exc)}"
        return message, message
    return code_preds, group_preds
print('INFO: starting gradio interface')
# Output components. predict() returns a dict keyed by these exact objects, so
# the same instances must also be passed as the Interface outputs.
box_class, box_group = (
    gr.Label(label=caption) for caption in ("Result class", "Result group")
)
def predict(text, top_n):
    """Gradio callback: map each output Label component to its result.

    Args:
        text: Input text from the textbox.
        top_n: Candidate count from the number input; coerced with ``int()``.

    Returns:
        ``{box_class: ..., box_group: ...}`` holding the prediction dicts, or the
        same ``"Error: ..."`` string in both slots if anything raises.
    """
    try:
        k = int(top_n)
        outputs = {
            box_class: classify_code(text, k),
            box_group: classify_group(text, k),
        }
    except Exception as err:
        failure = f"Error: {str(err)}"
        outputs = {box_class: failure, box_group: failure}
    return outputs
# Text pre-filled in the input box. Read through a context manager so the file
# handle is closed, and decode as UTF-8 rather than the platform default.
with open('default_input.json', encoding='utf-8') as defaults_file:
    default_input_text = json.load(defaults_file)['input_text']

iface = gr.Interface(
    title="ICD10-codes classification",
    description="",
    fn=predict,
    inputs=[
        gr.Textbox(label="Input text", value=default_input_text),
        gr.Number(label="TOP-N candidates", value=3),
    ],
    outputs=[box_class, box_group],
)
# `enable_queue` is not an Interface constructor option in Gradio 3+/4 (it is
# ignored with a deprecation warning or rejected); request queuing is enabled
# via Blocks.queue(), which exists in every Gradio version that provides the
# top-level gr.Textbox / gr.Label components used here.
iface.queue()
iface.launch()