calerio's picture
Update app.py
2858d67 verified
raw
history blame
4.92 kB
import os
import re

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import shap
import torch
from shap.maskers import Text as MaskerText
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# --- Runtime device -------------------------------------------------------
# Use the GPU when torch can see one, otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# transformers pipelines take an integer device: 0 = first GPU, -1 = CPU.
_PIPELINE_DEVICE = 0 if device.type == "cuda" else -1

# --- Severity classifier ----------------------------------------------------
_ADR_MODEL_ID = "calerio-uva/roberta-adr-model"
model = AutoModelForSequenceClassification.from_pretrained(_ADR_MODEL_ID)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(_ADR_MODEL_ID)

# --- Biomedical named-entity recogniser --------------------------------------
_NER_MODEL_ID = "d4data/biomedical-ner-all"
ner = pipeline(
    "ner",
    model=_NER_MODEL_ID,
    tokenizer=_NER_MODEL_ID,
    aggregation_strategy="simple",
    device=_PIPELINE_DEVICE,
)

# --- Classifier pipeline used as the SHAP model function ----------------------
# NOTE(review): with top_k=None the pipeline returns every label for each
# input, ordered by score rather than by label id — consumers must not rely
# on list position.
clf_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=_PIPELINE_DEVICE,
)
masker = MaskerText(tokenizer)
def shap_predict(texts):
texts = [str(t) for t in (texts.tolist() if isinstance(texts, np.ndarray) else texts)]
results = clf_pipeline(texts, truncation=True, padding=True, max_length=512)
return np.array([[entry['score'] for entry in sample] for sample in results])
# SHAP explainer wrapping shap_predict. output_names label the model's output
# columns, so their order must match the columns shap_predict returns.
# NOTE(review): confirm the installed shap version accepts algorithm="kernel"
# together with a Text masker; shap's text-model examples use the default
# ("auto") algorithm with a Text masker — verify which one is intended.
explainer = shap.Explainer(
    shap_predict,
    masker,
    algorithm="kernel",
    output_names=["Not Severe", "Severe"],
)
# Entity groups emitted by the NER pipeline (compared after lower-casing),
# bucketed into the three categories the UI reports.
SYMPTOM_TAGS = {
    "sign_symptom",
    "symptom",
}
DISEASE_TAGS = {
    "disease_disorder",
}
MED_TAGS = {
    "medication",
    "administration",
    "therapeutic_procedure",
}
def dedupe_and_filter(tokens):
    """Clean up a list of extracted terms.

    Each token is stripped of surrounding whitespace; terms shorter than three
    characters are dropped, and case-insensitive duplicates are removed while
    keeping the first spelling seen, in input order.

    Args:
        tokens: Iterable of strings.

    Returns:
        A new list of the surviving, stripped terms.
    """
    first_spelling = {}  # lower-cased term -> first stripped spelling seen
    for token in tokens:
        term = token.strip()
        if len(term) >= 3:
            first_spelling.setdefault(term.lower(), term)
    return list(first_spelling.values())
def classify_adr(text, show_confidence, show_shap):
clean = text.strip().replace("nan", " ").replace(" ", " ")
inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
ents = ner(clean)
spans = []
for e in ents:
grp, st, en, score = e['entity_group'].lower(), e['start'], e['end'], e.get('score', 1.0)
if spans and spans[-1]['group'] == grp and st <= spans[-1]['end']:
spans[-1]['end'] = max(spans[-1]['end'], en)
spans[-1]['score'] = max(spans[-1]['score'], score)
else:
spans.append({'group': grp, 'start': st, 'end': en, 'score': score})
for s in spans:
if s['group'] in MED_TAGS:
while s['end'] < len(clean) and clean[s['end']].isalpha():
s['end'] += 1
spans = [s for s in spans if s['score'] >= 0.6]
toks = [clean[s['start']:s['end']] for s in spans]
symptoms = dedupe_and_filter([t for t, s in zip(toks, spans) if s['group'] in SYMPTOM_TAGS])
diseases = dedupe_and_filter([t for t, s in zip(toks, spans) if s['group'] in DISEASE_TAGS])
medications = dedupe_and_filter([t for t, s in zip(toks, spans) if s['group'] in MED_TAGS])
interp = ("❗ High confidence this is a severe ADR." if probs[1] > 0.9 else
"⚠️ Borderline case β€” may be severe." if probs[1] > 0.5 else
"βœ… Likely not severe.")
conf_str = f"**Not Severe:** {probs[0]:.3f} \n**Severe:** {probs[1]:.3f}" if show_confidence else ""
indicators = symptoms + diseases + medications
why = 'Possible indicators: ' + ", ".join(indicators) if probs[1] > 0.5 and indicators else 'No severe indicators.'
if show_shap:
shap_vals = explainer([clean])
fig, ax = plt.subplots(figsize=(8, 4))
shap.plots.bar(shap_vals[0], show=False, ax=ax)
plt.tight_layout()
else:
fig = None
return (
conf_str,
interp,
why,
'\n'.join(symptoms) or 'None detected',
'\n'.join(diseases) or 'None detected',
'\n'.join(medications) or 'None detected',
fig
)
# --- Gradio UI ---------------------------------------------------------------
# Inputs are passed to classify_adr positionally; outputs are listed in the
# same order as the tuple classify_adr returns.
_ui_inputs = [
    gr.Textbox(lines=4, label="ADR Description"),
    gr.Checkbox(label="Show Prediction Confidence", value=True),
    gr.Checkbox(label="Generate SHAP Explanation (slower)", value=False),
]
_ui_outputs = [
    gr.Markdown(label="Predicted Probabilities"),
    gr.Markdown(label="Interpretation"),
    gr.Markdown(label="Why is this severe?"),
    gr.Textbox(label="Symptoms"),
    gr.Textbox(label="Diseases or Conditions"),
    gr.Textbox(label="Medications"),
    gr.Plot(label="SHAP Bar Plot"),
]

# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode` — confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=classify_adr,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="ADR Severity & NER Classifier 2",
    description="Paste an ADR description. The model predicts severity and extracts medical terms. SHAP analysis available (optional, slow).",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()