import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import shap
from shap.maskers import Text as MaskerText
import numpy as np
import matplotlib.pyplot as plt
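
# Run on the GPU when available, otherwise fall back to the CPU.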
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
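
# ADR severity classifier (fine-tuned RoBERTa) and its tokenizer.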
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
model.to(device)
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")
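
# Biomedical NER pipeline used to extract symptoms, diseases, and medications.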
ner = pipeline("ner", model="d4data/biomedical-ner-all", tokenizer="d4data/biomedical-ner-all",
               aggregation_strategy="simple", device=0 if device.type == "cuda" else -1)
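
# Classification pipeline wrapped for SHAP; top_k=None returns a score for every label.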
clf_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer,
                        top_k=None, device=0 if device.type == "cuda" else -1)

masker = MaskerText(tokenizer)

def shap_predict(texts):
    # Accept either a list of strings or the numpy array SHAP passes in.
    texts = [str(t) for t in (texts.tolist() if isinstance(texts, np.ndarray) else texts)]
    results = clf_pipeline(texts, truncation=True, padding=True, max_length=512)
    # With top_k=None the pipeline sorts each sample's entries by descending score,
    # so re-order them by label id to keep columns aligned with output_names below.
    return np.array([
        [entry['score'] for entry in sorted(sample, key=lambda e: model.config.label2id[e['label']])]
        for sample in results
    ])

explainer = shap.Explainer(shap_predict, masker, algorithm="kernel", output_names=["Not Severe", "Severe"])
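
# NER entity groups (lower-cased), bucketed into the three output fields.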
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}
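
# Keep entity strings of at least 3 characters and drop case-insensitive duplicates, preserving order.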
def dedupe_and_filter(tokens):
    seen, out = set(), []
    for w in tokens:
        clean_w = w.strip()
        if len(clean_w) < 3:
            continue
        lw = clean_w.lower()
        if lw not in seen:
            seen.add(lw)
            out.append(clean_w)
    return out
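
# Main inference routine: predict severity, extract entities, and optionally build a SHAP plot.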
def classify_adr(text, show_confidence, show_shap):
    # Light cleanup: drop stray "nan" tokens and collapse double spaces.
    clean = text.strip().replace("nan", " ").replace("  ", " ")

    # Severity prediction.
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # Entity extraction; merge adjacent or overlapping spans of the same entity group.
    ents = ner(clean)
    spans = []
    for e in ents:
        grp, st, en, score = e['entity_group'].lower(), e['start'], e['end'], e.get('score', 1.0)
        if spans and spans[-1]['group'] == grp and st <= spans[-1]['end']:
            spans[-1]['end'] = max(spans[-1]['end'], en)
            spans[-1]['score'] = max(spans[-1]['score'], score)
        else:
            spans.append({'group': grp, 'start': st, 'end': en, 'score': score})
    # Extend medication spans to the end of the current word so drug names are not cut mid-token.
    for s in spans:
        if s['group'] in MED_TAGS:
            while s['end'] < len(clean) and clean[s['end']].isalpha():
                s['end'] += 1
    spans = [s for s in spans if s['score'] >= 0.6]
    toks = [clean[s['start']:s['end']] for s in spans]

    symptoms = dedupe_and_filter([t for t, s in zip(toks, spans) if s['group'] in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, s in zip(toks, spans) if s['group'] in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, s in zip(toks, spans) if s['group'] in MED_TAGS])

    interp = ("❗ High confidence this is a severe ADR." if probs[1] > 0.9 else
              "⚠️ Borderline case: may be severe." if probs[1] > 0.5 else
              "✅ Likely not severe.")
    conf_str = f"**Not Severe:** {probs[0]:.3f}  \n**Severe:** {probs[1]:.3f}" if show_confidence else ""
    indicators = symptoms + diseases + medications
    why = 'Possible indicators: ' + ", ".join(indicators) if probs[1] > 0.5 and indicators else 'No severe indicators.'

    # Optional SHAP explanation (slow: runs the explainer on this single text).
    if show_shap:
        shap_vals = explainer([clean])
        fig, ax = plt.subplots(figsize=(8, 4))
        shap.plots.bar(shap_vals[0], show=False, ax=ax)
        plt.tight_layout()
    else:
        fig = None

    return (
        conf_str,
        interp,
        why,
        '\n'.join(symptoms) or 'None detected',
        '\n'.join(diseases) or 'None detected',
        '\n'.join(medications) or 'None detected',
        fig
    )
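
# Gradio interface: free-text input, two display toggles, and the outputs returned by classify_adr.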
demo = gr.Interface(
    fn=classify_adr,
    inputs=[
        gr.Textbox(lines=4, label="ADR Description"),
        gr.Checkbox(label="Show Prediction Confidence", value=True),
        gr.Checkbox(label="Generate SHAP Explanation (slower)", value=False)
    ],
    outputs=[
        gr.Markdown(label="Predicted Probabilities"),
        gr.Markdown(label="Interpretation"),
        gr.Markdown(label="Why is this severe?"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Plot(label="SHAP Bar Plot")
    ],
    title="ADR Severity & NER Classifier 2",
    description="Paste an ADR description. The model predicts severity and extracts medical terms. SHAP analysis is available (optional, slow).",
    allow_flagging="never"
)

if __name__ == '__main__':
    demo.launch()