Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from nnutils import create_var, index_select_ND | |
from chemutils import get_mol | |
import rdkit.Chem as Chem | |
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown'] | |
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1 | |
BOND_FDIM = 5 | |
MAX_NB = 15 | |
def onek_encoding_unk(x, allowable_set): | |
if x not in allowable_set: | |
x = allowable_set[-1] | |
return list(map(lambda s: x == s, allowable_set)) | |
def atom_features(atom): | |
return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST) | |
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5]) | |
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0]) | |
+ [atom.GetIsAromatic()]) | |
def bond_features(bond): | |
bt = bond.GetBondType() | |
return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]) | |
class JTMPN(nn.Module): | |
def __init__(self, hidden_size, depth): | |
super(JTMPN, self).__init__() | |
self.hidden_size = hidden_size | |
self.depth = depth | |
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False) | |
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False) | |
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size) | |
def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message): #tree_message[0] == vec(0) | |
fatoms = create_var(fatoms) | |
fbonds = create_var(fbonds) | |
agraph = create_var(agraph) | |
bgraph = create_var(bgraph) | |
binput = self.W_i(fbonds) | |
graph_message = F.relu(binput) | |
for i in range(self.depth - 1): | |
message = torch.cat([tree_message,graph_message], dim=0) | |
nei_message = index_select_ND(message, 0, bgraph) | |
nei_message = nei_message.sum(dim=1) #assuming tree_message[0] == vec(0) | |
nei_message = self.W_h(nei_message) | |
graph_message = F.relu(binput + nei_message) | |
message = torch.cat([tree_message,graph_message], dim=0) | |
nei_message = index_select_ND(message, 0, agraph) | |
nei_message = nei_message.sum(dim=1) | |
ainput = torch.cat([fatoms, nei_message], dim=1) | |
atom_hiddens = F.relu(self.W_o(ainput)) | |
mol_vecs = [] | |
for st,le in scope: | |
mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le | |
mol_vecs.append(mol_vec) | |
mol_vecs = torch.stack(mol_vecs, dim=0) | |
return mol_vecs | |
def tensorize(cand_batch, mess_dict): | |
fatoms,fbonds = [],[] | |
in_bonds,all_bonds = [],[] | |
total_atoms = 0 | |
total_mess = len(mess_dict) + 1 #must include vec(0) padding | |
scope = [] | |
for smiles,all_nodes,ctr_node in cand_batch: | |
mol = Chem.MolFromSmiles(smiles) | |
Chem.Kekulize(mol) #The original jtnn version kekulizes. Need to revisit why it is necessary | |
n_atoms = mol.GetNumAtoms() | |
ctr_bid = ctr_node.idx | |
for atom in mol.GetAtoms(): | |
fatoms.append( atom_features(atom) ) | |
in_bonds.append([]) | |
for bond in mol.GetBonds(): | |
a1 = bond.GetBeginAtom() | |
a2 = bond.GetEndAtom() | |
x = a1.GetIdx() + total_atoms | |
y = a2.GetIdx() + total_atoms | |
#Here x_nid,y_nid could be 0 | |
x_nid,y_nid = a1.GetAtomMapNum(),a2.GetAtomMapNum() | |
x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1 | |
y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1 | |
bfeature = bond_features(bond) | |
b = total_mess + len(all_bonds) #bond idx offseted by total_mess | |
all_bonds.append((x,y)) | |
fbonds.append( torch.cat([fatoms[x], bfeature], 0) ) | |
in_bonds[y].append(b) | |
b = total_mess + len(all_bonds) | |
all_bonds.append((y,x)) | |
fbonds.append( torch.cat([fatoms[y], bfeature], 0) ) | |
in_bonds[x].append(b) | |
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid: | |
if (x_bid,y_bid) in mess_dict: | |
mess_idx = mess_dict[(x_bid,y_bid)] | |
in_bonds[y].append(mess_idx) | |
if (y_bid,x_bid) in mess_dict: | |
mess_idx = mess_dict[(y_bid,x_bid)] | |
in_bonds[x].append(mess_idx) | |
scope.append((total_atoms,n_atoms)) | |
total_atoms += n_atoms | |
total_bonds = len(all_bonds) | |
fatoms = torch.stack(fatoms, 0) | |
fbonds = torch.stack(fbonds, 0) | |
agraph = torch.zeros(total_atoms,MAX_NB).long() | |
bgraph = torch.zeros(total_bonds,MAX_NB).long() | |
for a in range(total_atoms): | |
for i,b in enumerate(in_bonds[a]): | |
agraph[a,i] = b | |
for b1 in range(total_bonds): | |
x,y = all_bonds[b1] | |
for i,b2 in enumerate(in_bonds[x]): #b2 is offseted by total_mess | |
if b2 < total_mess or all_bonds[b2-total_mess][0] != y: | |
bgraph[b1,i] = b2 | |
return (fatoms, fbonds, agraph, bgraph, scope) | |