Spaces:
Running
Running
ZhaohanM
commited on
Commit
·
5676c75
1
Parent(s):
fde18d9
Update: SMILES-to-SELFIES conversion, UI polish, and usage guide
Browse files- .ipynb_checkpoints/app-checkpoint.py +472 -0
- .ipynb_checkpoints/requirements-checkpoint.txt +11 -0
- app.py +432 -193
- requirements.txt +8 -2
- utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py +73 -0
- utils/.ipynb_checkpoints/metric_learning_models_att_maps-checkpoint.py +325 -0
- utils/__pycache__/foldseek_util.cpython-38.pyc +0 -0
- utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc +0 -0
- utils/drug_tokenizer.py +8 -1
- utils/foldseek_util.py +167 -0
- utils/metric_learning_models_att_maps.py +2 -7
.ipynb_checkpoints/app-checkpoint.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, argparse, tempfile, shutil, base64, io
|
| 2 |
+
from flask import Flask, request, render_template_string
|
| 3 |
+
from werkzeug.utils import secure_filename
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import selfies
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import matplotlib
|
| 10 |
+
matplotlib.use("Agg")
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
from matplotlib import cm
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from utils.drug_tokenizer import DrugTokenizer
|
| 16 |
+
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
|
| 17 |
+
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
|
| 18 |
+
from utils.foldseek_util import get_struc_seq
|
| 19 |
+
|
| 20 |
+
# ───── Biopython fallback ───────────────────────────────────────
|
| 21 |
+
from Bio.PDB import PDBParser, MMCIFParser
|
| 22 |
+
from Bio.Data import IUPACData
|
| 23 |
+
|
| 24 |
+
# Lookup table from upper-cased three-letter residue names to one-letter codes,
# seeded from Biopython's IUPAC protein table.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
# Two extra residue names are added on top of the IUPAC table:
# SEC is mapped to "C" and PYL to "K".
three2one.update({"SEC": "C", "PYL": "K"})
|
| 26 |
+
def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter sequence of the first chain in a structure file.

    Biopython-only reader: residue names are looked up in ``three2one`` and
    anything unknown becomes ``"X"``.

    Args:
        path: Path to a PDB or mmCIF file. Files whose name ends in ``.cif``
            (case-insensitive) are read with ``MMCIFParser``; everything else
            is read with ``PDBParser``.

    Returns:
        The residue sequence of the first chain, one letter per residue.

    Raises:
        ValueError: If the parsed structure contains no chain.
    """
    # Compare the extension case-insensitively so ``.CIF`` uploads are not
    # handed to the PDB parser.
    if path.lower().endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)

    # A default of None avoids leaking a bare StopIteration to the caller.
    chain = next(parser.get_structure("P", path).get_chains(), None)
    if chain is None:
        raise ValueError(f"no chain found in structure file: {path}")

    # Bio.PDB marks water records with the hetero-flag "W" in ``res.id[0]``;
    # they are not part of the protein sequence and would otherwise show up as
    # a run of "X".
    return "".join(
        three2one.get(res.get_resname().upper(), "X")
        for res in chain
        if res.id[0] != "W"
    )
|
| 30 |
+
|
| 31 |
+
# ───── global paths / args ──────────────────────────────────────
|
| 32 |
+
# Path of the ``foldseek`` executable as found on PATH; ``shutil.which``
# returns None when it is not installed.
FOLDSEEK_BIN = shutil.which("foldseek")
# NOTE(review): presumably set to disable tokenizer worker parallelism and the
# related fork warning — confirm against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Adds the parent directory to the module search path.
sys.path.append("..")
|
| 35 |
+
|
| 36 |
+
def parse_config():
    """Build the command-line parser for the app and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options (encoder paths, aggregation
        mode, group size, learning rate, fusion type, device, checkpoint
        directory prefix and dataset name).
    """
    p = argparse.ArgumentParser()
    # Accepted but never read in this file.
    # NOTE(review): presumably absorbs the ``-f <connection file>`` argument
    # that Jupyter kernels inject — confirm before removing.
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    # The original closed the call before ``help=``, which is a SyntaxError;
    # ``help`` belongs inside the ``add_argument`` call.
    p.add_argument(
        "--dataset",
        default="BindingDB",
        help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')",
    )
    return p.parse_args()
|
| 49 |
+
|
| 50 |
+
# Parse the command line once at import time; ``args`` is read by the
# module-level setup and by the request handler below.
args = parse_config()
# Torch device string used by the ``.to(...)`` calls in this module.
DEVICE = args.device
|
| 52 |
+
|
| 53 |
+
# ───── tokenisers & encoders ────────────────────────────────────
|
| 54 |
+
# Protein side: tokenizer and masked-LM weights are both loaded from
# ``args.prot_encoder_path``.
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)

# Drug side: the project's tokenizer built with its defaults, and encoder
# weights loaded from ``args.drug_encoder_path``.
drug_tokenizer = DrugTokenizer()  # SELFIES
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)

# Wrap both encoders in the project's ``Pre_encoded`` module and move it to
# DEVICE; the request handler passes it to ``get_case_feature``.
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
|
| 61 |
+
|
| 62 |
+
# ─── collate fn ────────────────────────────────────────────────
|
| 63 |
+
def collate_fn(batch):
    """Tokenise a batch of ``(protein, drug, score)`` triples.

    Both sequences are encoded with the module-level tokenizers, padded and
    truncated to 512 positions, and returned as PyTorch tensors.

    Args:
        batch: Iterable of ``(protein_sequence, drug_sequence, score)`` items.

    Returns:
        A 5-tuple ``(protein_input_ids, protein_mask, drug_input_ids,
        drug_mask, scores)`` where the masks are the tokenizers' attention
        masks cast to ``bool``.
    """
    proteins, drugs, labels = zip(*batch)

    # Identical encoding options are used for both modalities.
    encode_kwargs = {
        "max_length": 512,
        "padding": "max_length",
        "truncation": True,
        "add_special_tokens": True,
        "return_tensors": "pt",
    }
    prot_enc = prot_tokenizer.batch_encode_plus(list(proteins), **encode_kwargs)
    drug_enc = drug_tokenizer.batch_encode_plus(list(drugs), **encode_kwargs)

    label_tensor = torch.tensor(list(labels))
    prot_mask = prot_enc["attention_mask"].bool()
    drug_mask = drug_enc["attention_mask"].bool()

    return (
        prot_enc["input_ids"],
        prot_mask,
        drug_enc["input_ids"],
        drug_mask,
        label_tensor,
    )
|
| 88 |
+
# def collate_fn_batch_encoding(batch):
|
| 89 |
+
|
| 90 |
+
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES, or return ``None`` on failure.

    The input is first parsed with RDKit; if RDKit cannot build a molecule
    from it, or if any step raises, ``None`` is returned instead of an error.

    Args:
        smiles: Candidate SMILES string.

    Returns:
        The string produced by ``selfies.encoder``, or ``None`` when the input
        is rejected or the conversion raises.
    """
    try:
        # RDKit acts as a validity gate only; the molecule object is not used.
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        return None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ───── single-case embedding ───────────────────────────────────
|
| 102 |
+
def get_case_feature(model, loader):
    """Encode the first batch from *loader* and return its tensors on the CPU.

    Args:
        model: Object exposing ``eval()`` and
            ``encoding(p_ids, p_mask, d_ids, d_mask)``.
        loader: Iterable yielding ``(p_ids, p_mask, d_ids, d_mask, label)``.

    Returns:
        A one-element list holding ``(protein_emb, drug_emb, protein_ids,
        drug_ids, protein_mask, drug_mask, None)``, all moved to the CPU.
        Only the first batch is processed; an empty loader yields ``None``.
    """
    model.eval()
    with torch.no_grad():
        for prot_ids, prot_mask, drug_ids, drug_mask, _label in loader:
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)

            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)

            # The label slot is kept as None so the tuple keeps its 7 fields.
            case = (
                prot_emb.cpu(),
                drug_emb.cpu(),
                prot_ids.cpu(),
                drug_ids.cpu(),
                prot_mask.cpu(),
                drug_mask.cpu(),
                None,
            )
            return [case]
|
| 112 |
+
|
| 113 |
+
# ───── helper:过滤特殊 token ───────────────────────────────────
|
| 114 |
+
def clean_tokens(ids, tokenizer):
    """Decode *ids* to token strings and drop the tokenizer's special tokens.

    Args:
        ids: Object with a ``tolist()`` method returning the token ids.
        tokenizer: Object exposing ``convert_ids_to_tokens`` and
            ``all_special_tokens``.

    Returns:
        List of decoded tokens, in order, with every special token removed.
    """
    decoded = tokenizer.convert_ids_to_tokens(ids.tolist())
    kept = []
    for token in decoded:
        if token in tokenizer.all_special_tokens:
            continue
        kept.append(token)
    return kept
|
| 117 |
+
|
| 118 |
+
# ───── visualisation ───────────────────────────────────────────
|
| 119 |
+
|
| 120 |
+
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and, optionally, a
    Top-20 protein-residue table for a chosen drug-token index.

    The token index shown on the x-axis (and accepted via *drug_idx*) is the
    position of that token in the decoded drug sequence with special tokens
    removed, before any pruning or column capping (1-based in the labels,
    0-based for the function argument).

    Parameters
    ----------
    model
        Callable as ``model(p_emb, d_emb, p_mask, d_mask)``; its second return
        value is used as the attention map.
    feats
        Sequence whose first item is the 7-tuple produced by
        ``get_case_feature``.
    drug_idx
        Optional 0-based drug-token index for the Top-20 table.

    Returns
    -------
    html : str
        Optional HTML table followed by a Base64-embedded PNG heat-map and a
        PDF download link. If no tokens remain, a short error paragraph.
    """
    # Function-scope import: token text is embedded in HTML below.
    from html import escape

    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ──────────────────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # ── forward pass ────────────────────────────────────────────────────
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
    # Original author's note: (n_p, n_d) after the batch axis is squeezed.
    attn = att_pd.squeeze(0).cpu()

    # ── decode tokens (special symbols removed) and clip to attn size ───────
    # NOTE(review): rows/columns are matched to tokens purely by position after
    # special tokens are dropped; confirm the attention map does not contain
    # rows/columns for those special positions.
    p_tokens = clean_tokens(p_ids[0], prot_tokenizer)[: attn.size(0)]
    d_tokens = clean_tokens(d_ids[0], drug_tokenizer)[: attn.size(1)]
    p_indices = list(range(1, len(p_tokens) + 1))  # 1-based positions
    d_indices = list(range(1, len(d_tokens) + 1))
    attn = attn[: len(p_tokens), : len(d_tokens)]

    # ``attn.max()`` raises on an empty tensor, so report the problem instead.
    if attn.numel() == 0:
        return ("<p style='color:red'><strong>No protein or drug tokens "
                "left to visualise.</strong></p>")

    # ── adaptive sparsity pruning ───────────────────────────────────────────
    # Keep rows/columns whose peak exceeds 5 % of the global maximum; if fewer
    # than three would survive, keep them all.
    thr = attn.max().item() * 0.05
    row_keep = attn.max(dim=1).values > thr
    col_keep = attn.max(dim=0).values > thr
    if row_keep.sum() < 3:
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True

    attn = attn[row_keep][:, col_keep]
    rows = row_keep.tolist()
    cols = col_keep.tolist()
    p_tokens = [tok for keep, tok in zip(rows, p_tokens) if keep]
    p_indices = [idx for keep, idx in zip(rows, p_indices) if keep]
    d_tokens = [tok for keep, tok in zip(cols, d_tokens) if keep]
    d_indices = [idx for keep, idx in zip(cols, d_indices) if keep]

    # ── cap column count at 150 for readability ─────────────────────────────
    # ``topk`` returns columns ordered by descending total attention, so the
    # capped heat-map is in that order rather than sequence order.
    if attn.size(1) > 150:
        topc = torch.topk(attn.sum(0), k=150).indices
        attn = attn[:, topc]
        picked = topc.tolist()
        d_tokens = [d_tokens[i] for i in picked]
        d_indices = [d_indices[i] for i in picked]

    # ── Top-20 table (only when a drug token was requested) ─────────────────
    table_html = ""
    if drug_idx is not None:
        # Map the original 0-based drug_idx to the current column position.
        if (drug_idx + 1) in d_indices:
            col_pos = d_indices.index(drug_idx + 1)
        elif 0 <= drug_idx < len(d_tokens):
            # NOTE(review): the requested token was pruned; this falls back to
            # whatever token now sits at that column. The heading below shows
            # the index of the token actually used.
            col_pos = drug_idx
        else:
            col_pos = None

        if col_pos is not None:
            col_vec = attn[:, col_pos]
            topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()

            rank_hdr = "".join(f"<th>{r + 1}</th>" for r in range(len(topk)))
            res_row = "".join(f"<td>{escape(str(p_tokens[i]))}</td>" for i in topk)
            pos_row = "".join(f"<td>{p_indices[i]}</td>" for i in topk)

            drug_tok_text = escape(str(d_tokens[col_pos]))
            orig_idx = d_indices[col_pos]

            table_html = (
                f"<h4 style='margin-bottom:6px'>"
                f"Drug token #{orig_idx} <code>{drug_tok_text}</code> "
                f"→ Top-20 Protein residues</h4>"
                "<table class='tg' style='margin-bottom:8px'>"
                f"<tr><th>Rank</th>{rank_hdr}</tr>"
                f"<tr><td>Residue</td>{res_row}</tr>"
                f"<tr><td>Position</td>{pos_row}</tr>"
                "</table>")

    # ── draw heat-map ───────────────────────────────────────────────────────
    x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
    y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
    fig_h = min(24, max(6, len(p_tokens) * 0.8))

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(attn.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")

        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        # The top/bottom keywords the original passed here are x-axis options
        # and are ignored for axis="y"; only the padding applies.
        ax.tick_params(axis="y", pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # Render once per format: PNG for the inline preview, PDF for the
        # vector download. The figure is closed only after both are written.
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
    finally:
        plt.close(fig)

    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

    html_heat = (
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
        f"title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
        f"<div style='margin-top:6px'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' "
        f"download='attention_heatmap.pdf'>Download PDF</a></div>"
    )

    return table_html + html_heat
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ───── Flask app ───────────────────────────────────────────────
|
| 282 |
+
# Flask application object; the route and the ``app.run`` call below use it.
app = Flask(__name__)
|
| 283 |
+
|
| 284 |
+
# Jinja template for the single page of the app (form, guidelines, results).
_INDEX_TEMPLATE = """
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">

<style>
:root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
*{box-sizing:border-box;margin:0;padding:0}
body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
.card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
label{font-weight:500;margin-bottom:6px;display:block}
textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
textarea{min-height:90px}
.btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
.btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
.btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
.grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
.vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}

/* ── tidy table for Top-20 list ─────────────────────────────── */
table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
table.tg th{background:var(--bg);font-weight:600}
</style>
</head>
<body>
<h1> Token-level Visualiser for Drug-Target Interaction</h1>

<!-- ───────────── Project Links (larger + spaced) ───────────── -->
<div style="margin-top:24px; text-align:center;">
<a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#10b981,#059669);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
🌐 Project Page
</a>

<a href="https://arxiv.org/abs/2406.01651" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#ef4444,#dc2626);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
📄 ArXiv: 2406.01651
</a>

<a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
💻 GitHub Repo
</a>
</div>

<!-- ───────────── Guidelines for Use ───────────── -->
<div class="card" style="margin-bottom:24px">
<h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
<ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
<li><strong>Convert protein structure into a structure-aware sequence:</strong>
Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
sequence will be generated using
<a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
based on 3D structures from
<a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a> or the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>

<li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
you must first visit the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a>
to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>

<li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
You can enter a SELFIES string directly, or paste a SMILES string.
SMILES will be automatically converted to SELFIES using
<a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
If conversion fails, a red error message will be displayed.</li>

<li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
to highlight the Top-20 interacting protein residues.</li>

<li>After inference, you can use the
“Download PDF” link to export a high-resolution vector version.</li>
</ul>
</div>

<div class="card">
<form method="POST" enctype="multipart/form-data" class="grid">

<div><label>Protein Structure (.pdb / .cif)</label>
<input type="file" name="structure_file">
<input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>

<div><label>Protein Sequence</label>
<textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>

<div><label>Drug Sequence (SELFIES/SMILES)</label>
<textarea name="drug_sequence" placeholder="[C][C][O]/cco …">{{ drug_seq }}</textarea></div>

<label>Drug atom/substructure index (1-based) – show Top-20 related protein residues</label>
<input type="number" name="drug_idx" min="1" style="width:120px">

<div class="grid grid-2">
<button class="btn btn-primary" type="submit" name="confirm_structure">Confirm Structure</button>
<button class="btn btn-primary" type="submit" name="Inference">Inference</button>
</div>
<button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
</form>

{% if structure_seq %}
<div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
{% endif %}
{% if result_html %}
<div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
{% endif %}
</div></body></html>
"""

_SMILES_ERROR_HTML = (
    "<p style='color:red'><strong>Failed to convert SMILES to SELFIES. "
    "Please check the input string.</strong></p>"
)


def _safe_tmp_path(raw: str) -> str:
    """Re-anchor a client-supplied path inside the temp directory ('' if empty)."""
    name = secure_filename(os.path.basename(raw or ""))
    return os.path.join(tempfile.gettempdir(), name) if name else ""


def _drug_input_to_selfies(raw: str):
    """Return ``(selfies_text, error_html)`` for a SELFIES-or-SMILES form value."""
    text = raw.strip()
    if not text:
        return "", None  # nothing entered is not a conversion failure
    # Heuristic kept from the original: SELFIES strings start with "[".
    if text.startswith("["):
        return raw, None
    converted = smiles_to_selfies(text)
    if converted:
        return converted, None
    return "", _SMILES_ERROR_HTML


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the single-page UI and handle its three form actions.

    POST actions are selected by the name of the submit button:

    * ``clear`` – reset every field.
    * ``confirm_structure`` – derive the protein sequence from the uploaded
      (or previously uploaded) structure file and normalise the drug input.
    * ``Inference`` – run the encoders and the fusion model on the submitted
      sequences and render the attention heat-map.

    Returns:
        The rendered HTML page.
    """
    protein_seq = drug_seq = structure_seq = ""
    tmp_structure_path = ""
    result_html = None
    drug_idx = None

    if request.method == "POST":
        # ``isdecimal`` only accepts characters that ``int`` can parse, unlike
        # ``isdigit`` (e.g. superscript digits).
        drug_idx_raw = request.form.get("drug_idx", "").strip()
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdecimal() else None

        upload = request.files.get("structure_file")
        upload_name = secure_filename(upload.filename) if upload and upload.filename else ""
        if upload_name:
            # NOTE(review): uploads with the same file name share one temp
            # path, so concurrent users can overwrite each other's file.
            tmp_structure_path = os.path.join(tempfile.gettempdir(), upload_name)
            upload.save(tmp_structure_path)
        else:
            # The hidden field is client-controlled: never trust it as a full
            # path, only as a file name inside the temp directory.
            tmp_structure_path = _safe_tmp_path(request.form.get("tmp_structure_path", ""))

        if "clear" in request.form:
            protein_seq = drug_seq = structure_seq = ""
            tmp_structure_path = ""

        elif "confirm_structure" in request.form and tmp_structure_path:
            errors = []
            try:
                parsed = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, None, plddt_mask=False)
                chain = list(parsed.keys())[0]
                _, _, structure_seq = parsed[chain]
            except Exception:
                # Deliberate best-effort: fall back to the Biopython reader
                # whenever the foldseek path fails for any reason.
                try:
                    structure_seq = simple_seq_from_structure(tmp_structure_path)
                except Exception:
                    structure_seq = ""
                    errors.append(
                        "<p style='color:red'><strong>Could not read a protein "
                        "sequence from the structure file.</strong></p>"
                    )
            protein_seq = structure_seq

            drug_seq, drug_error = _drug_input_to_selfies(request.form.get("drug_sequence", ""))
            if drug_error:
                errors.append(drug_error)
            if errors:
                result_html = "".join(errors)

        elif "Inference" in request.form:
            protein_seq = request.form.get("protein_sequence", "")
            # Convert SMILES here as well, so inference never receives a raw
            # SMILES string when the user skips "Confirm Structure".
            drug_seq, result_html = _drug_input_to_selfies(request.form.get("drug_sequence", ""))
            if protein_seq and drug_seq:
                loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
                                    collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                                    "best_model.ckpt")
                notice = ""
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                else:
                    # Make the missing checkpoint visible instead of silently
                    # plotting attention from weights that were never loaded.
                    notice = (
                        "<p style='color:red'><strong>Model checkpoint not found; "
                        "the attention map below was produced without loading "
                        "trained weights.</strong></p>"
                    )
                result_html = notice + visualize_attention(model, feats, drug_idx)

    return render_template_string(
        _INDEX_TEMPLATE,
        protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
        result_html=result_html, tmp_structure_path=tmp_structure_path)
|
| 469 |
+
|
| 470 |
+
# ───── run ─────────────────────────────────────────────────────
|
| 471 |
+
if __name__ == "__main__":
    # The server listens on every interface, so the Werkzeug interactive
    # debugger (which allows arbitrary code execution from the browser) must
    # be opt-in. Set FLASK_DEBUG=1 to enable it for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)
|
.ipynb_checkpoints/requirements-checkpoint.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Flask
|
| 2 |
+
torch
|
| 3 |
+
transformers
|
| 4 |
+
IPython
|
| 5 |
+
selfies
|
| 6 |
+
rdkit
|
| 7 |
+
biopython
|
| 8 |
+
matplotlib
|
| 9 |
+
scikit-learn
|
| 10 |
+
numpy
|
| 11 |
+
pandas
|
app.py
CHANGED
|
@@ -1,209 +1,66 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import
|
| 3 |
-
import
|
| 4 |
-
import torch
|
| 5 |
from torch.utils.data import DataLoader
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from utils.drug_tokenizer import DrugTokenizer
|
|
|
|
| 8 |
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
|
| 9 |
-
from
|
| 10 |
-
import tempfile
|
| 11 |
-
from flask import Flask, request, render_template_string
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def parse_config():
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
| 32 |
-
parser.add_argument("--save_path_prefix", type=str, default="save_model_ckp/", help="save the result in which directory")
|
| 33 |
-
parser.add_argument("--save_name", default="fine_tune", type=str, help="the name of the saved file")
|
| 34 |
-
parser.add_argument("--dataset", type=str, default="Human", help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
|
| 35 |
-
return parser.parse_args()
|
| 36 |
|
| 37 |
args = parse_config()
|
| 38 |
-
|
| 39 |
|
|
|
|
| 40 |
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
|
| 44 |
-
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
|
| 45 |
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
with torch.no_grad():
|
| 50 |
-
for step, batch in enumerate(dataloader):
|
| 51 |
-
prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask, label = batch
|
| 52 |
-
prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask = \
|
| 53 |
-
prot_input_ids.to(device), prot_attention_mask.to(device), drug_input_ids.to(device), drug_attention_mask.to(device)
|
| 54 |
-
|
| 55 |
-
prot_embed, drug_embed = model.encoding(prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask)
|
| 56 |
-
prot_embed, drug_embed = prot_embed.cpu(), drug_embed.cpu()
|
| 57 |
-
prot_input_ids, drug_input_ids = prot_input_ids.cpu(), drug_input_ids.cpu()
|
| 58 |
-
prot_attention_mask, drug_attention_mask = prot_attention_mask.cpu(), drug_attention_mask.cpu()
|
| 59 |
-
label = label.cpu()
|
| 60 |
-
|
| 61 |
-
return [(prot_embed, drug_embed, prot_input_ids, drug_input_ids, prot_attention_mask, drug_attention_mask, label)]
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
with torch.no_grad():
|
| 66 |
-
for batch in case_features:
|
| 67 |
-
prot, drug, prot_ids, drug_ids, prot_mask, drug_mask, label = batch
|
| 68 |
-
prot, drug = prot.to(device), drug.to(device)
|
| 69 |
-
prot_mask, drug_mask = prot_mask.to(device), drug_mask.to(device)
|
| 70 |
-
|
| 71 |
-
output, attention_weights = model(prot, drug, prot_mask, drug_mask)
|
| 72 |
-
prot_tokens = [prot_tokenizer.decode([pid.item()], skip_special_tokens=True) for pid in prot_ids.squeeze()]
|
| 73 |
-
drug_tokens = [drug_tokenizer.decode([did.item()], skip_special_tokens=True) for did in drug_ids.squeeze()]
|
| 74 |
-
tokens = prot_tokens + drug_tokens
|
| 75 |
-
|
| 76 |
-
attention_weights = attention_weights.unsqueeze(1)
|
| 77 |
-
|
| 78 |
-
# Generate HTML content using head_view with html_action='return'
|
| 79 |
-
html_head_view = head_view(attention_weights, tokens, sentence_b_start=512, html_action='return')
|
| 80 |
-
|
| 81 |
-
# Parse the HTML and modify it to replace sentence labels
|
| 82 |
-
html_content = html_head_view.data
|
| 83 |
-
html_content = html_content.replace("Sentence A -> Sentence A", "Protein -> Protein")
|
| 84 |
-
html_content = html_content.replace("Sentence B -> Sentence B", "Drug -> Drug")
|
| 85 |
-
html_content = html_content.replace("Sentence A -> Sentence B", "Protein -> Drug")
|
| 86 |
-
html_content = html_content.replace("Sentence B -> Sentence A", "Drug -> Protein")
|
| 87 |
-
|
| 88 |
-
# Save the modified HTML content to a temporary file
|
| 89 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
|
| 90 |
-
f.write(html_content.encode('utf-8'))
|
| 91 |
-
temp_file_path = f.name
|
| 92 |
-
|
| 93 |
-
return temp_file_path
|
| 94 |
-
|
| 95 |
-
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the input form and, on submit, the attention visualisation.

    GET shows an empty form. POST with ``clear`` resets both fields. Any other
    POST encodes the submitted protein/drug pair, runs ``FusionDTI`` and embeds
    the HTML produced by ``visualize_attention`` next to the form.

    Returns:
        The rendered HTML page.
    """
    protein_sequence = ""
    drug_sequence = ""
    result = None

    if request.method == 'POST' and 'clear' not in request.form:
        protein_sequence = request.form['protein_sequence']
        drug_sequence = request.form['drug_sequence']

        # A single-item "dataset": the trailing 1 is a placeholder label.
        dataset = [(protein_sequence, drug_sequence, 1)]
        dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn_batch_encoding)

        case_features = get_case_feature(encoding, dataloader, device)
        model = FusionDTI(446, 768, args).to(device)

        best_model_dir = f"{args.save_path_prefix}{args.dataset}_{args.fusion}"
        checkpoint_path = os.path.join(best_model_dir, 'best_model.ckpt')

        # NOTE(review): when the checkpoint is missing the model silently runs
        # with its freshly initialised weights.
        if os.path.exists(checkpoint_path):
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))

        html_file_path = visualize_attention(model, case_features, device, prot_tokenizer, drug_tokenizer)

        try:
            # The file is written as UTF-8 bytes, so it must be decoded as
            # UTF-8 rather than with the platform's default encoding.
            with open(html_file_path, 'r', encoding='utf-8') as f:
                result = f.read()
        finally:
            # visualize_attention creates the file with delete=False; remove it
            # here so that requests do not accumulate temp files.
            try:
                os.remove(html_file_path)
            except OSError:
                pass

    return render_template_string('''
    <html>
    <head>
        <title>Drug Target Interaction Visualization</title>
        <style>
            body { font-family: 'Times New Roman', Times, serif; margin: 40px; }
            h2 { color: #333; }
            .container { display: flex; }
            .left { flex: 1; padding-right: 20px; }
            .right { flex: 1; }
            textarea {
                width: 100%;
                padding: 12px 20px;
                margin: 8px 0;
                display: inline-block;
                border: 1px solid #ccc;
                border-radius: 4px;
                box-sizing: border-box;
                font-size: 16px;
                font-family: 'Times New Roman', Times, serif;
            }
            .button-container {
                display: flex;
                justify-content: space-between;
            }
            input[type="submit"], .button {
                width: 48%;
                color: white;
                padding: 14px 20px;
                margin: 8px 0;
                border: none;
                border-radius: 4px;
                cursor: pointer;
                font-size: 16px;
                font-family: 'Times New Roman', Times, serif;
            }
            .submit {
                background-color: #FFA500;
            }
            .submit:hover {
                background-color: #FF8C00;
            }
            .clear {
                background-color: #D3D3D3;
            }
            .clear:hover {
                background-color: #A9A9A9;
            }
            .result {
                font-size: 18px;
            }
        </style>
    </head>
    <body>
        <h2 style="text-align: center;">Drug Target Interaction Visualization</h2>
        <div class="container">
            <div class="left">
                <form method="post">
                    <label for="protein_sequence">Protein Sequence:</label>
                    <textarea id="protein_sequence" name="protein_sequence" rows="4" placeholder="Enter protein sequence here..." required>{{ protein_sequence }}</textarea><br>
                    <label for="drug_sequence">Drug Sequence:</label>
                    <textarea id="drug_sequence" name="drug_sequence" rows="4" placeholder="Enter drug sequence here..." required>{{ drug_sequence }}</textarea><br>
                    <div class="button-container">
                        <input type="submit" name="submit" class="button submit" value="Submit">
                        <input type="submit" name="clear" class="button clear" value="Clear">
                    </div>
                </form>
            </div>
            <div class="right" style="display: flex; justify-content: center; align-items: center;">
                {% if result %}
                    <div class="result">
                        {{ result|safe }}
                    </div>
                {% endif %}
            </div>
        </div>
    </body>
    </html>
    ''', protein_sequence=protein_sequence, drug_sequence=drug_sequence, result=result)
|
| 205 |
-
|
| 206 |
-
def collate_fn_batch_encoding(batch):
|
| 207 |
query1, query2, scores = zip(*batch)
|
| 208 |
|
| 209 |
query_encodings1 = prot_tokenizer.batch_encode_plus(
|
|
@@ -228,6 +85,388 @@ def collate_fn_batch_encoding(batch):
|
|
| 228 |
attention_mask2 = query_encodings2["attention_mask"].bool()
|
| 229 |
|
| 230 |
return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
-
|
|
|
|
| 233 |
app.run(debug=True, host="0.0.0.0", port=7860)
|
|
|
|
| 1 |
+
import os, sys, argparse, tempfile, shutil, base64, io
|
| 2 |
+
from flask import Flask, request, render_template_string
|
| 3 |
+
from werkzeug.utils import secure_filename
|
|
|
|
| 4 |
from torch.utils.data import DataLoader
|
| 5 |
+
import selfies
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import matplotlib
|
| 10 |
+
matplotlib.use("Agg")
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
from matplotlib import cm
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
from utils.drug_tokenizer import DrugTokenizer
|
| 16 |
+
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
|
| 17 |
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
|
| 18 |
+
from utils.foldseek_util import get_struc_seq
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
# ───── Biopython fallback ───────────────────────────────────────
|
| 21 |
+
from Bio.PDB import PDBParser, MMCIFParser
|
| 22 |
+
from Bio.Data import IUPACData
|
| 23 |
|
| 24 |
+
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
|
| 25 |
+
three2one.update({"SEC": "C", "PYL": "K"})
|
| 26 |
+
def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter amino-acid sequence of the first chain in *path*.

    Biopython-only fallback used when Foldseek cannot process the file.

    Args:
        path: Path to a PDB or mmCIF structure file. Files ending in ``.cif``
            or ``.mmcif`` (any case) are parsed as mmCIF, everything else as PDB.

    Returns:
        One letter per residue of the first chain; residues whose name is not
        in ``three2one`` are written as ``"X"``.

    Raises:
        ValueError: If the structure contains no chain.
    """
    # Match the extension case-insensitively so "FOO.CIF" is not fed to the PDB parser.
    is_cif = path.lower().endswith((".cif", ".mmcif"))
    parser = MMCIFParser(QUIET=True) if is_cif else PDBParser(QUIET=True)

    chain = next(parser.get_structure("P", path).get_chains(), None)
    if chain is None:
        raise ValueError(f"No chain found in structure file: {path}")

    letters = []
    for res in chain:
        name = res.get_resname().upper()
        # res.id[0] is Biopython's hetero flag (" " for polymer residues, "W"
        # for water, "H_xxx" for other hetero groups). Waters and ligands are
        # not part of the sequence, so skip them unless they are a known residue.
        if res.id[0] != " " and name not in three2one:
            continue
        letters.append(three2one.get(name, "X"))
    return "".join(letters)
|
| 30 |
+
|
| 31 |
+
# ───── global paths / args ──────────────────────────────────────
|
| 32 |
+
FOLDSEEK_BIN = shutil.which("foldseek")
|
| 33 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 34 |
+
sys.path.append("..")
|
| 35 |
|
| 36 |
def parse_config():
    """Parse command-line options for the demo app.

    Returns:
        argparse.Namespace: Encoder paths, fusion settings, device, and the
        checkpoint location (``save_path_prefix`` + ``dataset`` + ``fusion``).
    """
    p = argparse.ArgumentParser()
    # Accepts the "-f <file>" argument passed when run from a Jupyter kernel.
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    # `help` belongs inside the call; the original closed the parenthesis too
    # early, which is a SyntaxError.
    p.add_argument("--dataset", default="BindingDB",
                   help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
    return p.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
args = parse_config()
|
| 51 |
+
DEVICE = args.device
|
| 52 |
|
| 53 |
+
# ───── tokenisers & encoders ────────────────────────────────────
|
| 54 |
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
|
| 55 |
+
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
+
drug_tokenizer = DrugTokenizer() # SELFIES
|
| 58 |
+
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
|
| 59 |
|
| 60 |
+
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
# ─── collate fn ────────────────────────────────────────────────
|
| 63 |
+
def collate_fn(batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
query1, query2, scores = zip(*batch)
|
| 65 |
|
| 66 |
query_encodings1 = prot_tokenizer.batch_encode_plus(
|
|
|
|
| 85 |
attention_mask2 = query_encodings2["attention_mask"].bool()
|
| 86 |
|
| 87 |
return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
|
| 88 |
+
# def collate_fn_batch_encoding(batch):
|
| 89 |
+
|
| 90 |
+
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES.

    The string is first checked with RDKit; only if RDKit can parse it is it
    handed to the SELFIES encoder.

    Returns:
        The SELFIES string, or ``None`` when RDKit rejects the input or any
        step raises.
    """
    try:
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Best effort by design: every failure is reported to the caller as None.
        return None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ───── single-case embedding ───────────────────────────────────
|
| 102 |
+
def get_case_feature(model, loader):
    """Encode the first batch of *loader* and return its features on the CPU.

    Args:
        model: Encoder exposing ``encoding(prot_ids, prot_mask, drug_ids, drug_mask)``.
        loader: Iterable yielding ``(prot_ids, prot_mask, drug_ids, drug_mask, label)``.

    Returns:
        A one-element list holding the tuple ``(prot_emb, drug_emb, prot_ids,
        drug_ids, prot_mask, drug_mask, None)``. Only the first batch is
        processed; an empty loader yields ``None``.
    """
    model.eval()
    with torch.no_grad():
        for batch in loader:
            prot_ids, prot_mask, drug_ids, drug_mask, _label = batch
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)

            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)

            tensors = (prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask)
            # The label slot is kept (as None) so the tuple keeps seven fields.
            return [tuple(t.cpu() for t in tensors) + (None,)]
|
| 112 |
+
|
| 113 |
+
# ───── helper: filter out special tokens ───────────────────────
|
| 114 |
+
def clean_tokens(ids, tokenizer):
    """Convert token ids to token strings, dropping the tokenizer's special tokens.

    Args:
        ids: Object with a ``tolist()`` method returning the token ids.
        tokenizer: Object providing ``convert_ids_to_tokens`` and
            ``all_special_tokens``.

    Returns:
        list: The non-special tokens, in their original order.
    """
    # Read the special-token list once and use a set, rather than re-reading
    # the attribute and scanning a list for every token.
    special = set(tokenizer.all_special_tokens)
    return [tok for tok in tokenizer.convert_ids_to_tokens(ids.tolist()) if tok not in special]
|
| 117 |
+
|
| 118 |
+
# ───── visualisation ───────────────────────────────────────────
|
| 119 |
+
|
| 120 |
+
def _figure_to_base64(fig, fmt: str, **savefig_kwargs) -> str:
    """Render *fig* into an in-memory buffer as *fmt* and return it base64-encoded."""
    buf = io.BytesIO()
    fig.savefig(buf, format=fmt, **savefig_kwargs)
    return base64.b64encode(buf.getvalue()).decode()


def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """Render a Protein → Drug cross-attention heat-map and an optional Top-20 table.

    The heat-map shows only rows/columns whose peak attention exceeds 5 % of
    the global maximum (unless fewer than three would remain), capped at the
    150 strongest drug columns. Axis labels are ``<1-based position>:<token>``,
    where the position refers to the token list *after* special tokens are
    removed and *before* any pruning.

    Args:
        model: Fusion model; called as ``model(p_emb, d_emb, p_mask, d_mask)``
            and expected to return ``(output, attention)``.
        feats: One-element list as produced by ``get_case_feature``.
        drug_idx: 0-based position of a drug token (same numbering as the
            x-axis labels minus one). When given, a table of the protein
            tokens with the highest attention for that drug token is prepended.
            The table is computed on the unpruned matrix, so it always
            describes the requested token even if the heat-map pruned it.

    Returns:
        HTML fragment: optional table, the heat-map as a base64 PNG that links
        to itself for enlarging, and a PDF download link.
    """
    # Local import: tokens are interpolated into HTML that is rendered with
    # `|safe`, so they must be escaped.
    from html import escape

    max_columns = 150      # heat-map readability cap
    top_k = 20             # rows in the per-token table

    model.eval()
    with torch.no_grad():
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
    attn_all = att_pd.squeeze(0).cpu()

    def visible_tokens(ids, tokenizer):
        special = set(tokenizer.all_special_tokens)
        return [t for t in tokenizer.convert_ids_to_tokens(ids.tolist()) if t not in special]

    # NOTE(review): rows/columns are matched to the special-token-free token
    # lists purely by position. This is only correct if the model's attention
    # map does not contain special-token positions — TODO confirm.
    p_tokens_all = visible_tokens(p_ids[0], prot_tokenizer)[: attn_all.size(0)]
    d_tokens_all = visible_tokens(d_ids[0], drug_tokenizer)[: attn_all.size(1)]
    attn_all = attn_all[: len(p_tokens_all), : len(d_tokens_all)]

    if attn_all.numel() == 0:
        return ("<p style='color:red'><strong>No protein or drug tokens left to "
                "visualise.</strong></p>")

    # ── adaptive sparsity pruning ───────────────────────────────────────────
    thr = attn_all.max().item() * 0.05
    row_keep = attn_all.max(dim=1).values > thr
    col_keep = attn_all.max(dim=0).values > thr
    if row_keep.sum() < 3:
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True
    rows = row_keep.nonzero(as_tuple=True)[0].tolist()
    cols = col_keep.nonzero(as_tuple=True)[0].tolist()

    # ── cap column count, keeping the surviving columns in sequence order ───
    if len(cols) > max_columns:
        col_scores = attn_all[rows][:, cols].sum(0)
        strongest = torch.topk(col_scores, k=max_columns).indices.sort().values.tolist()
        cols = [cols[i] for i in strongest]

    attn = attn_all[rows][:, cols]
    x_labels = [f"{i + 1}:{d_tokens_all[i]}" for i in cols]
    y_labels = [f"{i + 1}:{p_tokens_all[i]}" for i in rows]

    # ── draw heat-map ───────────────────────────────────────────────────────
    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
    fig_h = min(24, max(6, len(y_labels) * 0.8))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(attn.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")

        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", top=True, bottom=False,
                       labeltop=True, labelbottom=False,
                       pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # Render each format exactly once, while the figure is still open.
        png_b64 = _figure_to_base64(fig, "png", dpi=140)  # raster preview
        pdf_b64 = _figure_to_base64(fig, "pdf")           # vector download
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    # ── optional Top-20 table for the requested drug token ──────────────────
    table_html = ""
    if drug_idx is not None:
        n_drug = attn_all.size(1)
        if 0 <= drug_idx < n_drug:
            col_vec = attn_all[:, drug_idx]
            top_rows = torch.topk(col_vec, k=min(top_k, col_vec.numel())).indices.tolist()

            rank_hdr = "".join(f"<th>{rank}</th>" for rank in range(1, len(top_rows) + 1))
            res_row = "".join(f"<td>{escape(p_tokens_all[i])}</td>" for i in top_rows)
            pos_row = "".join(f"<td>{i + 1}</td>" for i in top_rows)

            table_html = (
                f"<h4 style='margin-bottom:6px'>"
                f"Drug token #{drug_idx + 1} <code>{escape(d_tokens_all[drug_idx])}</code> "
                f"→ Top-20 Protein residues</h4>"
                "<table class='tg' style='margin-bottom:8px'>"
                f"<tr><th>Rank</th>{rank_hdr}</tr>"
                f"<tr><td>Residue</td>{res_row}</tr>"
                f"<tr><td>Position</td>{pos_row}</tr>"
                "</table>")
        else:
            table_html = (
                f"<p style='color:red'><strong>Drug token index {drug_idx + 1} is out of "
                f"range (valid: 1–{n_drug}).</strong></p>")

    # ── enlargeable preview + downloadable PDF ──────────────────────────────
    html_heat = (
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
        f"title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
        f"<div style='margin-top:6px'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' "
        f"download='attention_heatmap.pdf'>Download PDF</a></div>"
    )

    return table_html + html_heat
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ───── Flask app ───────────────────────────────────────────────
|
| 282 |
+
app = Flask(__name__)
|
| 283 |
+
|
| 284 |
+
# Prefix of every structure file this app stores in the temp directory; used to
# recognise our own files when the path is echoed back by the browser.
_UPLOAD_PREFIX = "fusiondti_"


def _save_uploaded_structure(upload) -> str:
    """Store an uploaded structure under a unique temp name and return its path."""
    # Keep only the (lower-cased) extension: it selects the parser later on.
    ext = os.path.splitext(secure_filename(upload.filename))[1].lower()
    fd, path = tempfile.mkstemp(prefix=_UPLOAD_PREFIX, suffix=ext)
    os.close(fd)
    upload.save(path)
    return path


def _trusted_structure_path(candidate: str) -> str:
    """Return *candidate* only if it is a file this app created in the temp dir, else ``""``."""
    if not candidate:
        return ""
    real = os.path.realpath(candidate)
    tmp_root = os.path.realpath(tempfile.gettempdir())
    is_ours = (os.path.dirname(real) == tmp_root
               and os.path.basename(real).startswith(_UPLOAD_PREFIX)
               and os.path.isfile(real))
    return real if is_ours else ""


def _drug_to_selfies(raw: str):
    """Normalise the drug field to SELFIES; return ``(selfies, error_html_or_None)``."""
    text = raw.strip()
    # Heuristic: SELFIES strings start with "["; everything else is treated as SMILES.
    if not text or text.startswith("["):
        return text, None
    converted = smiles_to_selfies(text)
    if converted:
        return converted, None
    return "", ("<p style='color:red'><strong>Failed to convert SMILES to SELFIES. "
                "Please check the input string.</strong></p>")


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the single-page UI.

    POST actions (selected by the submit button's name):
      * ``clear`` – reset all fields.
      * ``confirm_structure`` – derive the protein sequence from the uploaded
        structure (Foldseek first, Biopython as fallback) and normalise the
        drug field to SELFIES.
      * ``Inference`` – normalise the drug field, run the model and render the
        attention heat-map (plus the Top-20 table if a drug index was given).
    """
    protein_seq = drug_seq = structure_seq = ""
    tmp_structure_path = ""
    result_html = None

    if request.method == "POST":
        drug_idx_raw = request.form.get("drug_idx", "").strip()
        # The form is 1-based; visualize_attention expects 0-based.
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdecimal() else None

        struct = request.files.get("structure_file")
        if struct and struct.filename:
            tmp_structure_path = _save_uploaded_structure(struct)
        else:
            # The hidden field is client-controlled: accept it only when it
            # points at one of our own uploads, never at an arbitrary file.
            tmp_structure_path = _trusted_structure_path(
                request.form.get("tmp_structure_path", ""))

        if "clear" in request.form:
            protein_seq = drug_seq = structure_seq = ""
            tmp_structure_path = ""

        elif "confirm_structure" in request.form and tmp_structure_path:
            try:
                try:
                    parsed = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, None, plddt_mask=False)
                    chain = list(parsed.keys())[0]
                    _, _, structure_seq = parsed[chain]
                except Exception:
                    # Foldseek missing or failed: fall back to a plain sequence.
                    structure_seq = simple_seq_from_structure(tmp_structure_path)
            except Exception:
                structure_seq = ""
                result_html = ("<p style='color:red'><strong>Could not read the uploaded "
                               "structure file. Please upload a valid .pdb or .cif "
                               "file.</strong></p>")
            protein_seq = structure_seq

            drug_seq, drug_error = _drug_to_selfies(request.form.get("drug_sequence", ""))
            if drug_error:
                result_html = drug_error

        elif "Inference" in request.form:
            protein_seq = request.form.get("protein_sequence", "")
            # Convert here as well, so pasting SMILES and clicking "Inference"
            # directly behaves as the usage guide promises.
            drug_seq, result_html = _drug_to_selfies(request.form.get("drug_sequence", ""))
            if protein_seq and drug_seq:
                loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
                                    collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                                    "best_model.ckpt")
                # NOTE(review): if the checkpoint is missing, inference silently
                # runs with freshly initialised weights.
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                result_html = visualize_attention(model, feats, drug_idx)

    return render_template_string(
        # ───────────── HTML (original UI + new input fields) ─────────────
        """
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">

<style>
:root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
*{box-sizing:border-box;margin:0;padding:0}
body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
.card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
      border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
label{font-weight:500;margin-bottom:6px;display:block}
textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
      border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
textarea{min-height:90px}
.btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
      font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
.btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
.btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
.grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
.vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}

/* ── tidy table for Top-20 list ─────────────────────────────── */
table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
table.tg th{background:var(--bg);font-weight:600}
</style>
</head>
<body>
<h1> Token-level Visualiser for Drug-Target Interaction</h1>

<!-- ───────────── Project Links (larger + spaced) ───────────── -->
<div style="margin-top:24px; text-align:center;">
  <a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#10b981,#059669);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    🌐 Project Page
  </a>

  <a href="https://arxiv.org/abs/2406.01651" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#ef4444,#dc2626);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    📄 ArXiv: 2406.01651
  </a>

  <a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    💻 GitHub Repo
  </a>
</div>

<!-- ───────────── Guidelines for Use ───────────── -->
<div class="card" style="margin-bottom:24px">
  <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
  <ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
    <li><strong>Convert protein structure into a structure-aware sequence:</strong>
        Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
        sequence will be generated using
        <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
        based on 3D structures from
        <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a> or the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>

    <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
        you must first visit the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
        or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a>
        to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>

    <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
        You can enter a SELFIES string directly, or paste a SMILES string.
        SMILES will be automatically converted to SELFIES using
        <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
        If conversion fails, a red error message will be displayed.</li>

    <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
        to highlight the Top-20 interacting protein residues.</li>

    <li>After inference, you can use the
        “Download PDF” link to export a high-resolution vector version.</li>
  </ul>
</div>

<div class="card">
  <form method="POST" enctype="multipart/form-data" class="grid">

    <div><label>Protein Structure (.pdb / .cif)</label>
         <input type="file" name="structure_file">
         <input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>

    <div><label>Protein Sequence</label>
         <textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>

    <div><label>Drug Sequence (SELFIES/SMILES)</label>
         <textarea name="drug_sequence" placeholder="[C][C][O]/cco …">{{ drug_seq }}</textarea></div>

    <label>Drug atom/substructure index (1-based) – show Top-20 related protein residues</label>
    <input type="number" name="drug_idx" min="1" style="width:120px">

    <div class="grid grid-2">
      <button class="btn btn-primary" type="submit" name="confirm_structure">Confirm Structure</button>
      <button class="btn btn-primary" type="submit" name="Inference">Inference</button>
    </div>
    <button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
  </form>

  {% if structure_seq %}
    <div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
  {% endif %}
  {% if result_html %}
    <div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
  {% endif %}
</div></body></html>
        """,
        protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
        result_html=result_html, tmp_structure_path=tmp_structure_path)
|
| 469 |
|
| 470 |
+
# ───── run ─────────────────────────────────────────────────────
|
| 471 |
+
if __name__ == "__main__":
    # The Werkzeug debugger lets anyone who can reach the server execute
    # arbitrary Python, so it must not be on by default for a server bound to
    # all interfaces. Opt in explicitly with FLASK_DEBUG=1 for local work.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0") == "1"
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)
|
requirements.txt
CHANGED
|
@@ -1,5 +1,11 @@
|
|
| 1 |
Flask
|
| 2 |
torch
|
| 3 |
transformers
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
Flask
|
| 2 |
torch
|
| 3 |
transformers
|
| 4 |
+
IPython
|
| 5 |
+
selfies
|
| 6 |
+
rdkit
|
| 7 |
+
biopython
|
| 8 |
+
matplotlib
|
| 9 |
+
scikit-learn
|
| 10 |
+
numpy
|
| 11 |
+
pandas
|
utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
class DrugTokenizer:
    """Tokenizer for bracket-delimited drug strings such as ``[C][C][O]``.

    The text found between each ``[`` and ``]`` is looked up in a JSON
    vocabulary; anything missing from the vocabulary maps to the unk token id.
    """

    def __init__(self, vocab_path="data/Tokenizer/vocab.json", special_tokens_path="data/Tokenizer/special_tokens_map.json"):
        """Load the vocabulary and special tokens and cache their ids.

        Args:
            vocab_path: JSON file mapping token string -> integer id.
            special_tokens_path: JSON file describing the cls/sep/unk/pad tokens.
        """
        self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
        self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
        self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
        self.unk_token_id = self.vocab[self.special_tokens['unk_token']]
        self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
        self.id_to_token = {v: k for k, v in self.vocab.items()}

        self.all_special_tokens = list(self.special_tokens.values())

    def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
        """Read both JSON files and return ``(vocab, special_tokens)``.

        Each special-token entry may be either a dict carrying a ``'content'``
        key or a plain string; both are reduced to the token string.
        """
        with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
            vocab = json.load(vocab_file)
        with open(special_tokens_path, 'r', encoding='utf-8') as special_tokens_file:
            special_tokens_raw = json.load(special_tokens_file)

        special_tokens = {
            key: value['content'] if isinstance(value, dict) else value
            for key, value in special_tokens_raw.items()
        }
        return vocab, special_tokens

    @staticmethod
    def _plain_id(token_id):
        """Unwrap tensor / NumPy scalars so they hash like a Python int."""
        return token_id.item() if hasattr(token_id, 'item') else token_id

    def encode(self, sequence):
        """Encode one string as ``cls + token ids + sep`` with an all-ones mask.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` (Python lists).
        """
        tokens = re.findall(r'\[([^\[\]]+)\]', sequence)
        input_ids = [self.cls_token_id] + [self.vocab.get(token, self.unk_token_id) for token in tokens] + [self.sep_token_id]
        attention_mask = [1] * len(input_ids)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

    def batch_encode_plus(self, sequences, max_length, padding, truncation, add_special_tokens, return_tensors):
        """Encode several strings into fixed-length ``torch.long`` tensors.

        Every sequence is cut to ``max_length`` or right-padded with the pad id
        (mask 0) up to ``max_length``. ``padding``, ``truncation``,
        ``add_special_tokens`` and ``return_tensors`` are accepted for call
        compatibility but are not consulted.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` tensors of shape
            ``(len(sequences), max_length)``.
        """
        input_ids_list = []
        attention_mask_list = []

        for sequence in sequences:
            encoded = self.encode(sequence)
            input_ids = encoded['input_ids'][:max_length]
            attention_mask = encoded['attention_mask'][:max_length]

            pad_length = max_length - len(input_ids)
            if pad_length > 0:
                input_ids = input_ids + [self.pad_token_id] * pad_length
                attention_mask = attention_mask + [0] * pad_length

            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

        return {
            'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long)
        }

    def decode(self, input_ids, skip_special_tokens=False):
        """Turn ids back into a bracketed string.

        Args:
            input_ids: iterable of ids (Python ints or tensor / NumPy scalars).
            skip_special_tokens: when True, cls / sep / pad ids are dropped.

        Returns:
            The tokens joined as ``[tok1][tok2]...``; unknown ids become the
            unk token.
        """
        special_ids = (self.cls_token_id, self.sep_token_id, self.pad_token_id)
        tokens = []
        for raw_id in input_ids:
            token_id = self._plain_id(raw_id)
            if skip_special_tokens and token_id in special_ids:
                continue
            tokens.append(self.id_to_token.get(token_id, self.special_tokens['unk_token']))
        sequence = ''.join([f'[{token}]' for token in tokens])
        return sequence

    def convert_ids_to_tokens(self, ids):
        """Map ids to token strings; ids not in the vocabulary become the unk token."""
        unk_token = self.special_tokens['unk_token']
        return [self.id_to_token.get(self._plain_id(i), unk_token) for i in ids]
|
utils/.ipynb_checkpoints/metric_learning_models_att_maps-checkpoint.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
sys.path.append("../")
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.cuda.amp import autocast
|
| 11 |
+
from torch.nn import Module
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from torch.nn.utils.weight_norm import weight_norm
|
| 14 |
+
from torch.utils.data import Dataset
|
| 15 |
+
|
| 16 |
+
LOGGER = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
class FusionDTI(nn.Module):
    """Fuse protein and drug token embeddings and score the pair.

    ``args.fusion`` selects the fusion module: ``"CAN"`` (``CAN_Layer``),
    ``"BAN"`` (weight-normalised ``BANLayer``) or ``"Nan"`` (no fusion: the
    mean-pooled embeddings are concatenated and classified directly).
    """

    def __init__(self, prot_out_dim, disease_out_dim, args):
        """Build the projection layers and the fusion-specific head.

        Args:
            prot_out_dim: width of the incoming protein embeddings.
            disease_out_dim: width of the incoming drug embeddings.
            args: namespace providing ``fusion`` (plus whatever ``CAN_Layer``
                reads when ``fusion == "CAN"``).
        """
        super(FusionDTI, self).__init__()
        self.fusion = args.fusion
        self.drug_reg = nn.Linear(disease_out_dim, 512)
        self.prot_reg = nn.Linear(prot_out_dim, 512)

        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
        elif self.fusion == "BAN":
            self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
        elif self.fusion == "Nan":
            # NOTE(review): 1214 presumably equals prot_out_dim + disease_out_dim
            # for the encoders this was trained with -- confirm before reuse.
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        """Score a batch of protein/drug pairs.

        Returns:
            ``(score, att)``. ``att`` is the attention output of the CAN/BAN
            layer, or ``None`` for the ``"Nan"`` mode, which has no attention.

        Raises:
            ValueError: if ``self.fusion`` is not one of the supported modes.
        """
        if self.fusion == "Nan":
            prot_embed = prot_embed.mean(1)  # [batch_size, hidden]
            drug_embed = drug_embed.mean(1)  # [batch_size, hidden]
            joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
            score = self.mlp_classifier_nan(joint_embed)
            # No attention map exists on this path; previously `att` was left
            # unassigned and the return statement raised UnboundLocalError.
            att = None
        elif self.fusion in ("CAN", "BAN"):
            prot_embed = self.prot_reg(prot_embed)
            drug_embed = self.drug_reg(drug_embed)

            if self.fusion == "CAN":
                joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
            else:
                joint_embed, att = self.ban_layer(prot_embed, drug_embed)

            score = self.mlp_classifier(joint_embed)
        else:
            raise ValueError(f"Unsupported fusion mode: {self.fusion!r}")

        return score, att
|
| 53 |
+
|
| 54 |
+
class Pre_encoded(nn.Module):
    """Hold a protein encoder and a drug encoder and run both on a batch."""

    def __init__(self, prot_encoder, drug_encoder, args):
        """Store the two encoders.

        Args:
            prot_encoder: Protein structure-aware sequence encoder.
            drug_encoder: Drug SELFIES encoder.
            args: Accepted for signature compatibility; not read here.
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.drug_encoder = drug_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
        """Return ``(prot_embed, drug_embed)`` for the given token batches.

        The protein embedding is the protein encoder's ``logits`` output; the
        drug embedding is the drug encoder's ``last_hidden_state``.
        """
        prot_output = self.prot_encoder(
            input_ids=prot_input_ids,
            attention_mask=prot_attention_mask,
            return_dict=True,
        )
        drug_output = self.drug_encoder(
            input_ids=drug_input_ids,
            attention_mask=drug_attention_mask,
            return_dict=True,
        )
        return prot_output.logits, drug_output.last_hidden_state
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class CAN_Layer(nn.Module):
    """Multi-head cross-attention between grouped protein and drug tokens."""

    def __init__(self, hidden_dim, num_heads, args):
        """Create bias-free query/key/value projections for each modality.

        Args:
            hidden_dim: embedding width of both inputs.
            num_heads: number of attention heads.
            args: namespace providing ``agg_mode`` and ``group_size``.
        """
        super().__init__()
        self.agg_mode = args.agg_mode
        self.group_size = args.group_size  # Control Fusion Scale
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        # Registered as query_p, key_p, value_p, query_d, key_d, value_d
        # (same order and attribute names as before).
        for modality in ("p", "d"):
            for role in ("query", "key", "value"):
                setattr(self, f"{role}_{modality}", nn.Linear(hidden_dim, hidden_dim, bias=False))

    def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
        """Masked softmax over the key axis (dim 2) of ``logits``.

        Pairs where either mask is False have ``inf`` subtracted before the
        softmax; rows whose row-mask is False are zeroed afterwards.
        """
        batch, n_rows, n_cols, n_heads = logits.shape
        row_mask = mask_row.view(batch, n_rows, 1).repeat(1, 1, n_heads)
        col_mask = mask_col.view(batch, n_cols, 1).repeat(1, 1, n_heads)
        pair_mask = torch.einsum('blh, bkh->blkh', row_mask, col_mask)

        weights = torch.softmax(torch.where(pair_mask, logits, logits - inf), dim=2)
        row_mask_full = row_mask.view(batch, n_rows, 1, n_heads).repeat(1, 1, n_cols, 1)
        return torch.where(row_mask_full, weights, torch.zeros_like(weights))

    def apply_heads(self, x, n_heads, n_ch):
        """Split the last axis of ``x`` into ``(n_heads, n_ch)``."""
        target_shape = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*target_shape)

    def group_embeddings(self, x, mask, group_size):
        """Average consecutive ``group_size`` tokens; a group is valid if any member is."""
        batch, length, width = x.shape
        n_groups = length // group_size
        pooled = x.view(batch, n_groups, group_size, width).mean(dim=2)
        pooled_mask = mask.view(batch, n_groups, group_size).any(dim=2)
        return pooled, pooled_mask

    def forward(self, protein, drug, mask_prot, mask_drug):
        """Fuse both modalities.

        Returns:
            ``(query_embed, att_pd)``: the concatenated aggregated protein and
            drug embeddings, and the protein-to-drug attention weights averaged
            over heads.

        Raises:
            NotImplementedError: for an unknown ``agg_mode``.
        """
        # Group embeddings before applying multi-head attention.
        prot_g, prot_mask_g = self.group_embeddings(protein, mask_prot, self.group_size)
        drug_g, drug_mask_g = self.group_embeddings(drug, mask_drug, self.group_size)

        def heads(projection, tokens):
            return self.apply_heads(projection(tokens), self.num_heads, self.head_size)

        q_prot, k_prot, v_prot = (heads(p, prot_g) for p in (self.query_p, self.key_p, self.value_p))
        q_drug, k_drug, v_drug = (heads(p, drug_g) for p in (self.query_d, self.key_d, self.value_d))

        def attention(query, key, row_mask, col_mask):
            scores = torch.einsum('blhd, bkhd->blkh', query, key)
            return self.alpha_logits(scores, row_mask, col_mask)

        alpha_pp = attention(q_prot, k_prot, prot_mask_g, prot_mask_g)
        alpha_pd = attention(q_prot, k_drug, prot_mask_g, drug_mask_g)
        alpha_dp = attention(q_drug, k_prot, drug_mask_g, prot_mask_g)
        alpha_dd = attention(q_drug, k_drug, drug_mask_g, drug_mask_g)

        def mix(alpha, value):
            return torch.einsum('blkh, bkhd->blhd', alpha, value).flatten(-2)

        # Each modality averages its self-attended and cross-attended views.
        prot_embedding = (mix(alpha_pp, v_prot) + mix(alpha_pd, v_drug)) / 2
        drug_embedding = (mix(alpha_dp, v_prot) + mix(alpha_dd, v_drug)) / 2

        if self.agg_mode == "cls":
            prot_embed = prot_embedding[:, 0]  # [batch_size, hidden]
            drug_embed = drug_embedding[:, 0]  # [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            prot_embed = prot_embedding.mean(1)  # [batch_size, hidden]
            drug_embed = drug_embedding.mean(1)  # [batch_size, hidden]
        elif self.agg_mode == "mean":
            prot_embed = (prot_embedding * prot_mask_g.unsqueeze(-1)).sum(1) / prot_mask_g.sum(-1).unsqueeze(-1)
            drug_embed = (drug_embedding * drug_mask_g.unsqueeze(-1)).sum(1) / drug_mask_g.sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()

        query_embed = torch.cat([prot_embed, drug_embed], dim=1)
        att_pd = alpha_pd.mean(dim=-1)
        return query_embed, att_pd
|
| 182 |
+
|
| 183 |
+
class MlPdecoder_CAN(nn.Module):
    """Three Linear -> ReLU -> BatchNorm stages followed by a sigmoid output unit."""

    def __init__(self, input_dim):
        """Layer widths shrink ``input_dim -> input_dim // 2 -> input_dim // 4 -> 1``."""
        super().__init__()
        half_dim = input_dim // 2
        quarter_dim = input_dim // 4
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, half_dim)
        self.bn2 = nn.BatchNorm1d(half_dim)
        self.fc3 = nn.Linear(half_dim, quarter_dim)
        self.bn3 = nn.BatchNorm1d(quarter_dim)
        self.output = nn.Linear(quarter_dim, 1)

    def forward(self, x):
        """Return sigmoid scores with a trailing dimension of size 1."""
        stages = ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3))
        for linear, norm in stages:
            x = norm(torch.relu(linear(x)))
        return torch.sigmoid(self.output(x))
|
| 200 |
+
|
| 201 |
+
class MLPdecoder_BAN(nn.Module):
    """Three Linear -> ReLU -> BatchNorm stages followed by a sigmoid head of width ``binary``."""

    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        """Layer widths are ``in_dim -> hidden_dim -> hidden_dim -> out_dim -> binary``."""
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)
        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        """Return sigmoid-activated outputs of the final linear layer."""
        hidden = x
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3)):
            hidden = norm(F.relu(linear(hidden)))
        return torch.sigmoid(self.fc4(hidden))
|
| 219 |
+
|
| 220 |
+
class BANLayer(nn.Module):
    """Bilinear attention network layer.

    Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py
    """

    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        """Build the two input networks, the pooling layer and the attention weights.

        When ``h_out <= 32`` the attention uses the explicit ``h_mat`` /
        ``h_bias`` parameters; otherwise a weight-normalised linear ``h_net``.
        """
        super().__init__()

        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        joint_dim = h_dim * self.k
        self.v_net = FCNet([v_dim, joint_dim], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, joint_dim], act=act, dropout=dropout)
        if k > 1:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out > self.c:
            self.h_net = weight_norm(nn.Linear(joint_dim, h_out), dim=None)
        else:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, joint_dim).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())

        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        """Pool ``v`` and ``q`` through one attention map into a single vector per sample."""
        pooled = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if self.k > 1:
            # Average pooling times k is sum pooling over windows of size k.
            pooled = self.p_net(pooled.unsqueeze(1)).squeeze(1) * self.k
        return pooled

    def forward(self, v, q, softmax=False):
        """Return ``(logits, att_maps)``.

        ``logits`` is the batch-normalised sum of the per-head pooled features;
        ``att_maps`` holds one attention map per head, optionally
        softmax-normalised over all (v, q) pairs.
        """
        v_num = v.size(1)
        q_num = q.size(1)

        if self.h_out > self.c:
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            att_maps = self.h_net(d_.permute(0, 2, 3, 1))  # b x v x q x h_out
            att_maps = att_maps.permute(0, 3, 1, 2)  # b x h_out x v x q
        else:
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias

        if softmax:
            flat = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
            att_maps = flat.view(-1, self.h_out, v_num, q_num)

        # Accumulate the pooled features of every attention head.
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for head in range(1, self.h_out):
            logits = logits + self.attention_pooling(v_, q_, att_maps[:, head, :, :])

        return self.bn(logits), att_maps
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class FCNet(nn.Module):
    """Simple non-linear fully connected network.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        """Stack one (dropout, weight-normed linear, activation) stage per pair of widths.

        Args:
            dims: layer widths, input first and output last.
            act: name of an ``nn`` activation class; ``''`` disables it.
            dropout: dropout probability; no dropout layer is added when 0.
        """
        super().__init__()

        stages = []

        def add_stage(in_dim, out_dim):
            if dropout > 0:
                stages.append(nn.Dropout(dropout))
            stages.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
            if act != '':
                stages.append(getattr(nn, act)())

        # Hidden stages first, then the final stage built from the last two widths.
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            add_stage(in_dim, out_dim)
        add_stage(dims[-2], dims[-1])

        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.main(x)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class BatchFileDataset_Case(Dataset):
    """Dataset whose items are pre-saved batch files loaded with ``torch.load``."""

    # Keys read from every loaded batch file, in the order they are returned.
    _FIELDS = ('prot', 'drug', 'prot_ids', 'drug_ids', 'prot_mask', 'drug_mask', 'y')

    def __init__(self, file_list):
        """Remember the list of batch-file paths."""
        self.file_list = file_list

    def __len__(self):
        """Number of batch files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load file ``idx`` and return its seven fields as a tuple."""
        payload = torch.load(self.file_list[idx])
        return tuple(payload[field] for field in self._FIELDS)
|
utils/__pycache__/foldseek_util.cpython-38.pyc
ADDED
|
Binary file (4.86 kB). View file
|
|
|
utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
utils/drug_tokenizer.py
CHANGED
|
@@ -5,7 +5,7 @@ import torch.nn as nn
|
|
| 5 |
from torch.nn import functional as F
|
| 6 |
|
| 7 |
class DrugTokenizer:
|
| 8 |
-
def __init__(self, vocab_path="
|
| 9 |
self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
|
| 10 |
self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
|
| 11 |
self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
|
|
@@ -13,6 +13,8 @@ class DrugTokenizer:
|
|
| 13 |
self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
|
| 14 |
self.id_to_token = {v: k for k, v in self.vocab.items()}
|
| 15 |
|
|
|
|
|
|
|
| 16 |
def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
|
| 17 |
with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
|
| 18 |
vocab = json.load(vocab_file)
|
|
@@ -64,3 +66,8 @@ class DrugTokenizer:
|
|
| 64 |
tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
|
| 65 |
sequence = ''.join([f'[{token}]' for token in tokens])
|
| 66 |
return sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from torch.nn import functional as F
|
| 6 |
|
| 7 |
class DrugTokenizer:
|
| 8 |
+
def __init__(self, vocab_path="data/Tokenizer/vocab.json", special_tokens_path="data/Tokenizer/special_tokens_map.json"):
|
| 9 |
self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
|
| 10 |
self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
|
| 11 |
self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
|
|
|
|
| 13 |
self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
|
| 14 |
self.id_to_token = {v: k for k, v in self.vocab.items()}
|
| 15 |
|
| 16 |
+
self.all_special_tokens = list(self.special_tokens.values())
|
| 17 |
+
|
| 18 |
def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
|
| 19 |
with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
|
| 20 |
vocab = json.load(vocab_file)
|
|
|
|
| 66 |
tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
|
| 67 |
sequence = ''.join([f'[{token}]' for token in tokens])
|
| 68 |
return sequence
|
| 69 |
+
|
| 70 |
+
# --- 新增 ---
|
| 71 |
+
def convert_ids_to_tokens(self, ids):
|
| 72 |
+
"""list[int] → list[str],跳过未知 id"""
|
| 73 |
+
return [self.id_to_token.get(i, self.special_tokens['unk_token']) for i in ids]
|
utils/foldseek_util.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import re
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from Bio.PDB import PDBParser, MMCIFParser
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
sys.path.append(".")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Get structural seqs from pdb file
|
| 15 |
+
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: bool = "auto",
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """Run foldseek on a structure file and return its structural sequences.

    Args:
        foldseek: Binary executable file of foldseek.
        path: Path to pdb file.
        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.
        process_id: Process ID for temporary files. This is used for parallel processing.
        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are
            from the pdb file. The default "auto" enables masking when the file text
            contains "alphafold" (case-insensitive).
        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the
            structure will be masked.
        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq).
    """
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters from being split or interpreted.
    cmd = [str(foldseek), "structureto3didescriptor"]
    if not foldseek_verbose:
        cmd += ["-v", "0"]
    cmd += ["--threads", "1", "--chain-name-mode", "1", str(path), tmp_save_path]

    try:
        subprocess.run(cmd, check=False)

        # Check whether the structure is predicted by AlphaFold2
        if plddt_mask == "auto":
            with open(path, "r") as r:
                plddt_mask = "alphafold" in r.read().lower()

        seq_dict = {}
        name = os.path.basename(path)
        plddts = None  # extracted at most once; it does not depend on the line
        with open(tmp_save_path, "r") as r:
            for line in r:
                desc, seq, struc_seq = line.split("\t")[:3]

                # Mask low plddt
                if plddt_mask:
                    try:
                        if plddts is None:
                            plddts = extract_plddt(path)
                        if len(plddts) != len(struc_seq):
                            raise ValueError(f"Length mismatch: {len(plddts)} != {len(struc_seq)}")

                        # Mask regions with plddt < threshold
                        indices = np.where(plddts < plddt_threshold)[0]
                        np_seq = np.array(list(struc_seq))
                        np_seq[indices] = "#"
                        struc_seq = "".join(np_seq)

                    except Exception as e:
                        # Best effort: keep the unmasked sequence if masking fails.
                        print(f"Error: {e}")
                        print(f"Failed to mask plddt for {name}")

                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    if chain not in seq_dict:
                        combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                        seq_dict[chain] = (seq, struc_seq, combined_seq)

        return seq_dict
    finally:
        # Always remove foldseek's temporary outputs, even when parsing failed.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def extract_plddt(pdb_path: str, chain_id: str = None) -> np.ndarray:
    """
    Extract per-residue plddt scores from a structure file.

    The score of a residue is the mean B-factor of its atoms, read from the
    first model of the structure.

    Args:
        pdb_path: Path to a '.pdb' or '.cif' file.

        chain_id: Chain to read. If None, chain "A" is used when present,
            otherwise the first chain of the model.

    Returns:
        plddts: plddt scores, one per residue of the selected chain.

    Raises:
        ValueError: If the file extension is neither '.cif' nor '.pdb'.
        KeyError: If an explicitly requested chain_id is not in the model.
    """

    # Initialize parser
    if pdb_path.endswith(".cif"):
        parser = MMCIFParser()
    elif pdb_path.endswith(".pdb"):
        parser = PDBParser()
    else:
        raise ValueError("Invalid file format for plddt extraction. Must be '.cif' or '.pdb'.")

    structure = parser.get_structure('protein', pdb_path)
    model = structure[0]

    if chain_id is not None:
        chain = model[chain_id]
    elif "A" in model:
        chain = model["A"]
    else:
        # Structures without a chain "A" used to raise KeyError here.
        chain = next(iter(model))

    # Extract plddt scores
    plddts = [
        np.mean([atom.get_bfactor() for atom in residue])
        for residue in chain
    ]
    return np.array(plddts)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def transform_pdb_dir(foldseek: str, pdb_dir: str, seq_type: str, save_path: str):
    """
    Transform a directory of pdb files into a fasta file.

    Args:
        foldseek: Binary executable file of foldseek.

        pdb_dir: Directory of pdb files.

        seq_type: Type of sequence to be extracted. Must be "aa" or "foldseek"

        save_path: Path to save the fasta file.
    """
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert seq_type in ["aa", "foldseek"], "seq_type must be 'aa' or 'foldseek'!"

    tmp_save_path = f"get_struc_seq_{time.time()}.tsv"
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters from being split or interpreted.
    cmd = [str(foldseek), "structureto3didescriptor", "--chain-name-mode", "1", str(pdb_dir), tmp_save_path]

    try:
        subprocess.run(cmd, check=False)

        with open(tmp_save_path, "r") as r, open(save_path, "w") as w:
            for line in r:
                protein_id, aa_seq, foldseek_seq = line.strip().split("\t")[:3]

                if seq_type == "aa":
                    w.write(f">{protein_id}\n{aa_seq}\n")
                else:
                    w.write(f">{protein_id}\n{foldseek_seq.lower()}\n")
    finally:
        # Always remove foldseek's temporary outputs, even when parsing failed.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
    # Manual smoke test against developer-local paths.
    foldseek = "/sujin/bin/foldseek"
    # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # get_struc_seq has no `plddt_path` parameter (passing it raised TypeError);
    # plddt scores are read from the structure file itself.
    res = get_struc_seq(foldseek, test_path, plddt_threshold=70.)
    print(res["A"][1].lower())
|
utils/metric_learning_models_att_maps.py
CHANGED
|
@@ -175,15 +175,10 @@ class CAN_Layer(nn.Module):
|
|
| 175 |
|
| 176 |
query_embed = torch.cat([prot_embed, drug_embed], dim=1)
|
| 177 |
|
| 178 |
-
|
| 179 |
-
att = torch.zeros(1, 1, 1024, 1024)
|
| 180 |
-
att[:, :, :512, :512] = alpha_pp.mean(dim=-1) # Protein to Protein
|
| 181 |
-
att[:, :, :512, 512:] = alpha_pd.mean(dim=-1) # Protein to Drug
|
| 182 |
-
att[:, :, 512:, :512] = alpha_dp.mean(dim=-1) # Drug to Protein
|
| 183 |
-
att[:, :, 512:, 512:] = alpha_dd.mean(dim=-1) # Drug to Drug
|
| 184 |
|
| 185 |
# print("query_embed:", query_embed.shape)
|
| 186 |
-
return query_embed,
|
| 187 |
|
| 188 |
class MlPdecoder_CAN(nn.Module):
|
| 189 |
def __init__(self, input_dim):
|
|
|
|
| 175 |
|
| 176 |
query_embed = torch.cat([prot_embed, drug_embed], dim=1)
|
| 177 |
|
| 178 |
+
att_pd = alpha_pd.mean(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
# print("query_embed:", query_embed.shape)
|
| 181 |
+
return query_embed, att_pd
|
| 182 |
|
| 183 |
class MlPdecoder_CAN(nn.Module):
|
| 184 |
def __init__(self, input_dim):
|