Spaces:
Runtime error
Runtime error
import os | |
from rdkit import Chem | |
from rdkit.Chem import Draw, AllChem | |
from rdkit.Geometry import Point3D | |
from rdkit import RDLogger | |
import imageio | |
import networkx as nx | |
import numpy as np | |
import rdkit.Chem | |
import wandb | |
import matplotlib.pyplot as plt | |
class MolecularVisualization:
    """Draw molecular graphs with RDKit and optionally log the images to wandb."""

    def __init__(self, remove_h, dataset_infos):
        """
        remove_h: stored on the instance; not read by any method of this class.
        dataset_infos: must expose ``atom_decoder``, indexed by integer atom type
            to give the element symbol passed to ``Chem.Atom``.
        """
        self.remove_h = remove_h
        self.dataset_infos = dataset_infos

    @staticmethod
    def _gif_path(path):
        """Return the GIF path next to ``path``, named after its last component.

        ``path`` is normalised first so a trailing separator does not produce a
        file literally called ``.gif``.
        """
        norm = os.path.normpath(path)
        return os.path.join(os.path.dirname(norm), '{}.gif'.format(os.path.basename(norm)))

    @staticmethod
    def _align_to_final(mols):
        """Give every molecule 2-D coordinates copied, atom by atom, from the last molecule.

        ``None`` entries are skipped. If the last molecule is ``None``, nothing is
        aligned, and drawing falls back to RDKit's own layout.
        """
        final_molecule = mols[-1]
        if final_molecule is None:
            return
        AllChem.Compute2DCoords(final_molecule)
        final_conf = final_molecule.GetConformer()
        coords = []
        for j in range(final_molecule.GetNumAtoms()):
            position = final_conf.GetAtomPosition(j)
            coords.append((position.x, position.y, position.z))

        for mol in mols:
            if mol is None:
                continue
            # Compute2DCoords creates the conformer whose positions are overwritten below.
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            # zip stops at the shorter side: atoms beyond the final molecule's count
            # keep their own layout instead of raising IndexError.
            for j, (x, y, z) in zip(range(mol.GetNumAtoms()), coords):
                conf.SetAtomPosition(j, Point3D(x, y, z))

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """
        Convert a single graph into an RDKit molecule.

        node_list: length-n sequence of atom-type indices, decoded through
            ``dataset_infos.atom_decoder``. Entries equal to -1 are skipped.
        adjacency_matrix: n x n bond-order matrix where 1 = single, 2 = double,
            3 = triple, 4 = aromatic, and any other value means no bond. Only the
            strict upper triangle is read, so the matrix is treated as symmetric.

        Returns the molecule, or None if RDKit raises KekulizeException while
        finalising it.
        """
        atom_decoder = self.dataset_infos.atom_decoder
        mol = Chem.RWMol()

        # Map graph node index -> RDKit atom index. They diverge once a -1 node is skipped.
        node_to_idx = {}
        for i in range(len(node_list)):
            if node_list[i] == -1:
                continue
            node_to_idx[i] = mol.AddAtom(Chem.Atom(atom_decoder[int(node_list[i])]))

        for ix, row in enumerate(adjacency_matrix):
            # Only traverse the upper half of the symmetric matrix.
            for iy in range(ix + 1, len(row)):
                bond = row[iy]
                if bond == 1:
                    bond_type = Chem.rdchem.BondType.SINGLE
                elif bond == 2:
                    bond_type = Chem.rdchem.BondType.DOUBLE
                elif bond == 3:
                    bond_type = Chem.rdchem.BondType.TRIPLE
                elif bond == 4:
                    bond_type = Chem.rdchem.BondType.AROMATIC
                else:
                    continue
                # A bond that touches a skipped (-1) node has no atom to attach to.
                if ix not in node_to_idx or iy not in node_to_idx:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            mol = mol.GetMol()
        except rdkit.Chem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
        """
        Save up to ``num_molecules_to_visualize`` molecules as PNGs under ``path``.

        molecules: sequence of (node_list, adjacency_matrix) pairs. Both elements
            are converted with ``.numpy()`` (presumably torch tensors; confirm
            against callers).
        log: wandb key for the images. Pass None to skip wandb logging.

        Molecules that cannot be built or drawn are reported and skipped.
        """
        os.makedirs(path, exist_ok=True)

        print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
        if num_molecules_to_visualize > len(molecules):
            print(f"Shortening to {len(molecules)}")
            num_molecules_to_visualize = len(molecules)

        for i in range(num_molecules_to_visualize):
            file_path = os.path.join(path, 'molecule_{}.png'.format(i))
            mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
            if mol is None:
                # mol_from_graphs already printed the reason; RDKit cannot draw None.
                continue
            try:
                Draw.MolToFile(mol, file_path)
            except rdkit.Chem.KekulizeException:
                print("Can't kekulize molecule")
                continue
            if wandb.run and log is not None:
                print(f"Saving {file_path} to wandb")
                wandb.log({log: wandb.Image(file_path)}, commit=True)

    def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
        """
        Render a chain of graphs (one per frame) as PNG frames, a GIF and a grid image.

        nodes_list / adjacency_matrix: indexed by frame along the first axis;
            ``nodes_list.shape[0]`` is the frame count.
        trainer: unused; kept for interface compatibility.

        Frames and the grid image are written inside ``path``, which is created if
        missing. The GIF is written next to ``path`` and named after its last
        component.

        Returns the list of molecules, one per frame. An entry is None where
        RDKit could not build that frame's molecule.
        """
        num_frames = nodes_list.shape[0]
        if num_frames == 0:
            return []

        RDLogger.DisableLog('rdApp.*')
        mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(num_frames)]

        # Reuse the final molecule's layout so atoms do not jump between frames.
        self._align_to_final(mols)

        os.makedirs(path, exist_ok=True)
        save_paths = []
        for frame, mol in enumerate(mols):
            if mol is None:
                continue
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            Draw.MolToFile(mol, file_name, size=(300, 300), legend=f"Frame {frame}")
            save_paths.append(file_name)

        if save_paths:
            imgs = [imageio.imread(fn) for fn in save_paths]
            # Repeat the last frame so the final molecule stays visible at the end of the GIF.
            imgs.extend([imgs[-1]] * 10)
            gif_path = self._gif_path(path)
            imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
            if wandb.run:
                print(f"Saving {gif_path} to wandb")
                wandb.log({"chain": wandb.Video(gif_path, fps=5, format="gif")}, commit=True)

        try:
            img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
            chain_name = os.path.basename(os.path.normpath(path))
            img.save(os.path.join(path, '{}_grid_image.png'.format(chain_name)))
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
        return mols
class NonMolecularVisualization:
    """Draw generic graphs with networkx/matplotlib and optionally log them to wandb."""

    @staticmethod
    def _gif_path(path):
        """Return the GIF path next to ``path``, named after its last component.

        ``path`` is normalised first so a trailing separator does not produce a
        file literally called ``.gif``.
        """
        norm = os.path.normpath(path)
        return os.path.join(os.path.dirname(norm), '{}.gif'.format(os.path.basename(norm)))

    def to_networkx(self, node_list, adjacency_matrix):
        """
        Convert a single graph to a networkx graph.

        node_list: length-n sequence of node labels. Entries equal to -1 are
            skipped. Every other node i gets ``number=i`` and
            ``symbol = color_val = node_list[i]``.
        adjacency_matrix: n x n array. Every entry >= 1 becomes an edge with
            ``color = float(entry)`` and ``weight = 3 * entry``.
        """
        graph = nx.Graph()
        for i in range(len(node_list)):
            if node_list[i] == -1:
                continue
            graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])

        rows, cols = np.where(adjacency_matrix >= 1)
        for u, v in zip(rows.tolist(), cols.tolist()):
            # add_edge would silently re-create a skipped node without its attributes.
            if u not in graph or v not in graph:
                continue
            edge_type = adjacency_matrix[u][v]
            graph.add_edge(u, v, color=float(edge_type), weight=3 * edge_type)
        return graph

    def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
        """
        Draw ``graph`` and save it to ``path``.

        Nodes are coloured by column 1 of the eigenvectors returned by
        ``np.linalg.eigh`` on the normalized Laplacian. eigh sorts eigenvalues in
        ascending order, so this is the eigenvector of the second-smallest
        eigenvalue.

        pos: mapping node -> position. When None, a spring layout with
            ``iterations`` steps is computed.
        largest_component: when True, only the largest connected component is drawn.
        """
        if largest_component and graph.number_of_nodes() > 0:
            graph = graph.subgraph(max(nx.connected_components(graph), key=len))

        if pos is None:
            pos = nx.spring_layout(graph, iterations=iterations)

        if graph.number_of_nodes() >= 2:
            _, eigvecs = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
            node_color = eigvecs[:, 1]
            # Symmetric colour range, so 0 maps to the centre of the diverging colormap.
            m = max(np.abs(np.min(node_color)), np.max(node_color))
        else:
            # With fewer than two nodes there is no second eigenvector; use the neutral colour.
            node_color = [0.0] * graph.number_of_nodes()
            m = 1.0

        plt.figure()
        try:
            nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=node_color,
                    cmap=plt.cm.coolwarm, vmin=-m, vmax=m, edge_color='grey')
            plt.tight_layout()
            plt.savefig(path)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close("all")

    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
        """
        Save up to ``num_graphs_to_visualize`` graphs as PNGs under ``path``.

        graphs: sequence of (node_list, adjacency_matrix) pairs. Both elements are
            converted with ``.numpy()`` (presumably torch tensors; confirm against
            callers).
        log: wandb key for the images. Pass None to skip wandb logging.
        """
        os.makedirs(path, exist_ok=True)
        # Clamp to the number of graphs available instead of raising IndexError.
        num_graphs_to_visualize = min(num_graphs_to_visualize, len(graphs))

        for i in range(num_graphs_to_visualize):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
            self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
            if wandb.run and log is not None:
                im = plt.imread(file_path)
                wandb.log({log: [wandb.Image(im, caption=file_path)]})

    def visualize_chain(self, path, nodes_list, adjacency_matrix):
        """
        Render a chain of graphs (one per frame) as PNG frames and a GIF.

        nodes_list / adjacency_matrix: indexed by frame along the first axis;
            ``nodes_list.shape[0]`` is the frame count.

        Frames are written inside ``path``, which is created if missing. The GIF
        is written next to ``path`` and named after its last component. Returns
        None.
        """
        num_frames = nodes_list.shape[0]
        if num_frames == 0:
            return

        graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(num_frames)]
        # Lay out the final graph once (fixed seed) and reuse it for every frame.
        final_pos = nx.spring_layout(graphs[-1], seed=0)

        os.makedirs(path, exist_ok=True)
        save_paths = []
        for frame, graph in enumerate(graphs):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            self.visualize_non_molecule(graph=graph, pos=final_pos, path=file_name)
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        # Repeat the last frame so the final graph stays visible at the end of the GIF.
        imgs.extend([imgs[-1]] * 10)
        gif_path = self._gif_path(path)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
        if wandb.run:
            wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})