import numpy as np
import torch
import re
import wandb

try:
    from rdkit import Chem
    print("Found rdkit, all good")
except ModuleNotFoundError as e:
    use_rdkit = False
    from warnings import warn
    warn("Didn't find rdkit, this will fail")
    assert use_rdkit, "Didn't find rdkit"


allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3, 'Si': 4, 'P': [3, 5],
                 'S': 4, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2], 'Bi': [3, 5], 'Se': [2, 4, 6]}

bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
             Chem.rdchem.BondType.AROMATIC]

ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}
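
# Integer edge types index into bond_dict: 0 = no bond, 1 = single, 2 = double,
# 3 = triple, 4 = aromatic. For example, edge_types[i, j] == 2 means atoms i and j
# share a double bond. ATOM_VALENCY maps atomic numbers to their default valence.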


class BasicMolecularMetrics(object):
    def __init__(self, dataset_info, train_smiles=None):
        self.atom_decoder = dataset_info.atom_decoder
        self.dataset_info = dataset_info

        # Retrieve dataset smiles only for qm9 currently.
        self.dataset_smiles_list = train_smiles

    def compute_validity(self, generated):
        """ generated: list of pairs (atom_types, edge_types). """
        valid = []
        num_components = []
        all_smiles = []
        for graph in generated:
            atom_types, edge_types = graph
            mol = build_molecule(atom_types, edge_types, self.dataset_info.atom_decoder)
            smiles = mol2smiles(mol)
            try:
                mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
                num_components.append(len(mol_frags))
            except Exception:
                pass
            if smiles is not None:
                try:
                    mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
                    largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
                    smiles = mol2smiles(largest_mol)
                    valid.append(smiles)
                    all_smiles.append(smiles)
                except Chem.rdchem.AtomValenceException:
                    print("Valence error in GetMolFrags")
                    all_smiles.append(None)
                except Chem.rdchem.KekulizeException:
                    print("Can't kekulize molecule")
                    all_smiles.append(None)
            else:
                all_smiles.append(None)
        return valid, len(valid) / len(generated), np.array(num_components), all_smiles

    def compute_uniqueness(self, valid):
        """ valid: list of SMILES strings. """
        return list(set(valid)), len(set(valid)) / len(valid)

    def compute_novelty(self, unique):
        num_novel = 0
        novel = []
        if self.dataset_smiles_list is None:
            print("Dataset smiles is None, novelty computation skipped")
            return 1, 1
        for smiles in unique:
            if smiles not in self.dataset_smiles_list:
                novel.append(smiles)
                num_novel += 1
        return novel, num_novel / len(unique)

    def compute_relaxed_validity(self, generated):
        valid = []
        for graph in generated:
            atom_types, edge_types = graph
            mol = build_molecule_with_partial_charges(atom_types, edge_types, self.dataset_info.atom_decoder)
            smiles = mol2smiles(mol)
            if smiles is not None:
                try:
                    mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
                    largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
                    smiles = mol2smiles(largest_mol)
                    valid.append(smiles)
                except Chem.rdchem.AtomValenceException:
                    print("Valence error in GetMolFrags")
                except Chem.rdchem.KekulizeException:
                    print("Can't kekulize molecule")
        return valid, len(valid) / len(generated)

    def evaluate(self, generated):
        """ generated: list of pairs (atom_types: n [int], edge_types: n x n [int]).
            The atom and edge types should already be masked. """
        valid, validity, num_components, all_smiles = self.compute_validity(generated)
        nc_mu = num_components.mean() if len(num_components) > 0 else 0
        nc_min = num_components.min() if len(num_components) > 0 else 0
        nc_max = num_components.max() if len(num_components) > 0 else 0
        print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}%")
        print(f"Number of connected components of {len(generated)} molecules: "
              f"min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")

        relaxed_valid, relaxed_validity = self.compute_relaxed_validity(generated)
        print(f"Relaxed validity over {len(generated)} molecules: {relaxed_validity * 100 :.2f}%")
        if relaxed_validity > 0:
            unique, uniqueness = self.compute_uniqueness(relaxed_valid)
            print(f"Uniqueness over {len(relaxed_valid)} valid molecules: {uniqueness * 100 :.2f}%")

            if self.dataset_smiles_list is not None:
                _, novelty = self.compute_novelty(unique)
                print(f"Novelty over {len(unique)} unique valid molecules: {novelty * 100 :.2f}%")
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []
        return ([validity, relaxed_validity, uniqueness, novelty], unique,
                dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles)
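
# Illustrative usage sketch (assumes a `dataset_info` object carrying a QM9-style
# atom_decoder such as ['H', 'C', 'N', 'O', 'F']; graphs are (atom_types, edge_types)
# integer tensors -- hypothetical inputs, not a fixed API):
#   metrics = BasicMolecularMetrics(dataset_info, train_smiles=['CCO', 'C=O'])
#   graphs = [(torch.tensor([1, 3]), torch.tensor([[0, 2], [2, 0]]))]      # C=O
#   scores, unique_smiles, nc_stats, all_smiles = metrics.evaluate(graphs)
#   validity, relaxed_validity, uniqueness, novelty = scores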


def mol2smiles(mol):
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return None
    return Chem.MolToSmiles(mol)
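
# Illustrative sketch: sanitization failures return None instead of raising, which is
# what compute_validity relies on; a valid molecule yields its canonical SMILES.
#   print(mol2smiles(Chem.MolFromSmiles('CCO')))               # 'CCO'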


def build_molecule(atom_types, edge_types, atom_decoder, verbose=False):
    if verbose:
        print("building new molecule")

    mol = Chem.RWMol()
    for atom in atom_types:
        a = Chem.Atom(atom_decoder[atom.item()])
        mol.AddAtom(a)
        if verbose:
            print("Atom added: ", atom.item(), atom_decoder[atom.item()])

    edge_types = torch.triu(edge_types)
    all_bonds = torch.nonzero(edge_types)
    for i, bond in enumerate(all_bonds):
        if bond[0].item() != bond[1].item():
            mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
            if verbose:
                print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
                      bond_dict[edge_types[bond[0], bond[1]].item()])
    return mol
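
# Illustrative sketch (hypothetical inputs, assuming atom_decoder = ['H', 'C', 'N', 'O', 'F']):
# carbon dioxide built from two double bonds, then converted to SMILES.
#   atom_types = torch.tensor([1, 3, 3])                       # C, O, O
#   edge_types = torch.tensor([[0, 2, 2],
#                              [2, 0, 0],
#                              [2, 0, 0]])                     # C=O, C=O
#   mol = build_molecule(atom_types, edge_types, ['H', 'C', 'N', 'O', 'F'])
#   print(mol2smiles(mol))                                     # 'O=C=O'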


def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
    if verbose:
        print("\nbuilding new molecule")

    mol = Chem.RWMol()
    for atom in atom_types:
        a = Chem.Atom(atom_decoder[atom.item()])
        mol.AddAtom(a)
        if verbose:
            print("Atom added: ", atom.item(), atom_decoder[atom.item()])

    edge_types = torch.triu(edge_types)
    all_bonds = torch.nonzero(edge_types)
    for i, bond in enumerate(all_bonds):
        if bond[0].item() != bond[1].item():
            mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
            if verbose:
                print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
                      bond_dict[edge_types[bond[0], bond[1]].item()])
            # Add a formal charge to over-valent atoms, e.g. [O+], [N+], [S+].
            # [O-], [N-], [S-], [NH+] etc. are not supported.
            flag, atomid_valence = check_valency(mol)
            if verbose:
                print("flag, valence", flag, atomid_valence)
            if flag:
                continue
            else:
                assert len(atomid_valence) == 2
                idx = atomid_valence[0]
                v = atomid_valence[1]
                an = mol.GetAtomWithIdx(idx).GetAtomicNum()
                if verbose:
                    print("atomic num of atom with a large valence", an)
                if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
                    mol.GetAtomWithIdx(idx).SetFormalCharge(1)
                    # print("Formal charge added")
    return mol
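
# Illustrative sketch (hypothetical inputs, assuming atom_decoder = ['H', 'C', 'N', 'O', 'F']):
# a nitrogen with four single bonds exceeds its default valence of 3, so the loop
# assigns it a +1 formal charge, giving a tetraalkylammonium-style cation.
#   atom_types = torch.tensor([2, 1, 1, 1, 1])                 # N, C, C, C, C
#   edge_types = torch.zeros(5, 5, dtype=torch.long)
#   edge_types[0, 1:] = 1                                      # N-C single bonds
#   mol = build_molecule_with_partial_charges(atom_types, edge_types, ['H', 'C', 'N', 'O', 'F'])
#   print(mol2smiles(mol))                                     # 'C[N+](C)(C)C'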


# Functions from GDSS
def check_valency(mol):
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
        return True, None
    except ValueError as e:
        e = str(e)
        p = e.find('#')
        e_sub = e[p:]
        atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
        return False, atomid_valence
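
# Illustrative sketch (hypothetical over-valent molecule): an oxygen with three explicit
# single bonds fails sanitization; the offending atom index and its valence are parsed
# out of RDKit's error message, so the expected result is (False, [0, 3]). A valid
# molecule returns (True, None).
#   bad = Chem.RWMol()
#   for symbol in ('O', 'C', 'C', 'C'):
#       bad.AddAtom(Chem.Atom(symbol))
#   for j in (1, 2, 3):
#       bad.AddBond(0, j, Chem.rdchem.BondType.SINGLE)
#   print(check_valency(bad))                                  # (False, [0, 3])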


def correct_mol(m):
    mol = m

    no_correct = False
    flag, _ = check_valency(mol)
    if flag:
        no_correct = True

    while True:
        flag, atomid_valence = check_valency(mol)
        if flag:
            break
        else:
            assert len(atomid_valence) == 2
            idx = atomid_valence[0]
            v = atomid_valence[1]
            queue = []
            check_idx = 0
            for b in mol.GetAtomWithIdx(idx).GetBonds():
                bond_type = int(b.GetBondType())  # 12 == AROMATIC in RDKit's enum
                queue.append((b.GetIdx(), bond_type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
                if bond_type == 12:
                    check_idx += 1
            queue.sort(key=lambda tup: tup[1], reverse=True)
            if queue[-1][1] == 12:
                return None, no_correct
            elif len(queue) > 0:
                start = queue[check_idx][2]
                end = queue[check_idx][3]
                t = queue[check_idx][1] - 1
                mol.RemoveBond(start, end)
                if t >= 1:
                    mol.AddBond(start, end, bond_dict[t])
    return mol, no_correct
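
# Illustrative sketch (hypothetical over-valent input): a nitrogen carrying four single
# bonds fails check_valency, so one offending bond is demoted by one order (a single
# bond, t == 0, is removed outright), leaving a valid but possibly fragmented graph.
#   atom_types = torch.tensor([2, 1, 1, 1, 1])
#   edge_types = torch.zeros(5, 5, dtype=torch.long)
#   edge_types[0, 1:] = 1
#   broken = build_molecule(atom_types, edge_types, ['H', 'C', 'N', 'O', 'F'])
#   fixed, was_already_valid = correct_mol(broken)
#   print(was_already_valid)                                   # False
#   print(mol2smiles(fixed))                                   # e.g. 'C.CN(C)C'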


def valid_mol_can_with_seg(m, largest_connected_comp=True):
    if m is None:
        return None
    sm = Chem.MolToSmiles(m, isomericSmiles=True)
    if largest_connected_comp and '.' in sm:
        vsm = [(s, len(s)) for s in sm.split('.')]  # e.g. 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.')
        vsm.sort(key=lambda tup: tup[1], reverse=True)
        mol = Chem.MolFromSmiles(vsm[0][0])
    else:
        mol = Chem.MolFromSmiles(sm)
    return mol
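
# Illustrative sketch: for a multi-fragment molecule, only the fragment with the longest
# SMILES string is kept.
#   m = Chem.MolFromSmiles('C.CC.CCO')
#   largest = valid_mol_can_with_seg(m)
#   print(Chem.MolToSmiles(largest))                           # 'CCO'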


if __name__ == '__main__':
    smiles_mol = 'C1CCC1'
    print("Smiles mol %s" % smiles_mol)
    chem_mol = Chem.MolFromSmiles(smiles_mol)
    block_mol = Chem.MolToMolBlock(chem_mol)
    print("Block mol:")
    print(block_mol)


use_rdkit = True


def check_stability(atom_types, edge_types, dataset_info, debug=False, atom_decoder=None):
    if atom_decoder is None:
        atom_decoder = dataset_info.atom_decoder

    n_bonds = np.zeros(len(atom_types), dtype='int')

    for i in range(len(atom_types)):
        for j in range(i + 1, len(atom_types)):
            n_bonds[i] += abs((edge_types[i, j] + edge_types[j, i]) / 2)
            n_bonds[j] += abs((edge_types[i, j] + edge_types[j, i]) / 2)

    n_stable_bonds = 0
    for atom_type, atom_n_bond in zip(atom_types, n_bonds):
        possible_bonds = allowed_bonds[atom_decoder[atom_type]]
        if type(possible_bonds) == int:
            is_stable = possible_bonds == atom_n_bond
        else:
            is_stable = atom_n_bond in possible_bonds
        if not is_stable and debug:
            print("Invalid bonds for atom %s with %d bonds" % (atom_decoder[atom_type], atom_n_bond))
        n_stable_bonds += int(is_stable)

    molecule_stable = n_stable_bonds == len(atom_types)
    return molecule_stable, n_stable_bonds, len(atom_types)
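
# Illustrative sketch (hypothetical inputs, atom_decoder = ['H', 'C', 'N', 'O', 'F']):
# carbon dioxide is stable: the carbon carries 4 bonds and each oxygen carries 2.
#   atom_types = torch.tensor([1, 3, 3])
#   edge_types = torch.tensor([[0, 2, 2], [2, 0, 0], [2, 0, 0]])
#   mol_stable, n_stable, n_atoms = check_stability(atom_types, edge_types, None,
#                                                   atom_decoder=['H', 'C', 'N', 'O', 'F'])
#   print(mol_stable, n_stable, n_atoms)                       # True 3 3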


def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
    """ molecule_list: list of pairs (atom_types, edge_types). """
    if not dataset_info.remove_h:
        print('Analyzing molecule stability...')

        molecule_stable = 0
        nr_stable_bonds = 0
        n_atoms = 0
        n_molecules = len(molecule_list)

        for i, mol in enumerate(molecule_list):
            atom_types, edge_types = mol
            validity_results = check_stability(atom_types, edge_types, dataset_info)

            molecule_stable += int(validity_results[0])
            nr_stable_bonds += int(validity_results[1])
            n_atoms += int(validity_results[2])

        # Stability
        fraction_mol_stable = molecule_stable / float(n_molecules)
        fraction_atm_stable = nr_stable_bonds / float(n_atoms)
        validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
        if wandb.run:
            wandb.log(validity_dict)
    else:
        validity_dict = {'mol_stable': -1, 'atm_stable': -1}

    metrics = BasicMolecularMetrics(dataset_info, train_smiles)
    rdkit_metrics = metrics.evaluate(molecule_list)
    all_smiles = rdkit_metrics[-1]
    if wandb.run:
        nc = rdkit_metrics[-2]
        dic = {'Validity': rdkit_metrics[0][0], 'Relaxed Validity': rdkit_metrics[0][1],
               'Uniqueness': rdkit_metrics[0][2], 'Novelty': rdkit_metrics[0][3],
               'nc_max': nc['nc_max'], 'nc_mu': nc['nc_mu']}
        wandb.log(dic)

    return validity_dict, rdkit_metrics, all_smiles
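
# Illustrative usage sketch (hypothetical DatasetInfo stub; in practice the object comes
# from the dataset module and must expose `atom_decoder` and `remove_h`):
#   class DatasetInfo:
#       atom_decoder = ['H', 'C', 'N', 'O', 'F']
#       remove_h = True
#   molecules = [(torch.tensor([1, 3]), torch.tensor([[0, 2], [2, 0]]))]   # C=O
#   validity_dict, rdkit_metrics, all_smiles = compute_molecular_metrics(
#       molecules, train_smiles=['CCO'], dataset_info=DatasetInfo())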