# Source: LGGM-Text2Graph/analysis/visualization.py
# YuWang0103's picture
# Commit message: Upload 41 files
# Commit: 6b59850 (verified)
import os
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Geometry import Point3D
from rdkit import RDLogger
import imageio
import networkx as nx
import numpy as np
import rdkit.Chem
import wandb
import matplotlib.pyplot as plt
class MolecularVisualization:
    """Draw graphs as RDKit molecules: single images, per-frame chains and GIFs."""

    def __init__(self, remove_h, dataset_infos):
        """
        Store the settings used when decoding graphs.
        remove_h: kept on the instance; not read by any method of this class
        dataset_infos: object whose ``atom_decoder`` is indexed with an integer
            node type to obtain the value passed to ``Chem.Atom``
        """
        self.remove_h = remove_h
        self.dataset_infos = dataset_infos

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """
        Convert one graph to an rdkit molecule
        node_list: node types of a single graph, indexed as a 1-D sequence;
            an entry equal to -1 is skipped (no atom is created for it)
        adjacency_matrix: square matrix of bond codes for the same graph;
            1/2/3/4 map to single/double/triple/aromatic, anything else is
            treated as "no bond"
        Returns the molecule, or None if ``GetMol`` raises KekulizeException.
        """
        # dictionary to map integer value to the char of atom
        atom_decoder = self.dataset_infos.atom_decoder

        # (bond code found in the adjacency matrix, rdkit bond type)
        bond_types = (
            (1, Chem.rdchem.BondType.SINGLE),
            (2, Chem.rdchem.BondType.DOUBLE),
            (3, Chem.rdchem.BondType.TRIPLE),
            (4, Chem.rdchem.BondType.AROMATIC),
        )

        # create empty editable mol object
        mol = Chem.RWMol()

        # add atoms to mol and keep track of index
        node_to_idx = {}
        for i, node_type in enumerate(node_list):
            if node_type == -1:
                continue
            node_to_idx[i] = mol.AddAtom(Chem.Atom(atom_decoder[int(node_type)]))

        for ix, row in enumerate(adjacency_matrix):
            if ix not in node_to_idx:
                # the row belongs to a skipped node: it has no atom to bond
                continue
            for iy, bond in enumerate(row):
                # only traverse half the symmetric matrix
                if iy <= ix:
                    continue
                # compare with == (not a dict lookup) so array scalars behave
                # exactly like plain ints
                bond_type = next((kind for code, kind in bond_types if bond == code), None)
                if bond_type is None:
                    continue
                if iy not in node_to_idx:
                    # bond towards a skipped node: ignore instead of KeyError
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            mol = mol.GetMol()
        except rdkit.Chem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
        """
        Save the first molecules of a list as PNG files and optionally log them
        path: output directory, created if missing
        molecules: sequence whose items expose the node types at index 0 and
            the adjacency matrix at index 1, each providing ``.numpy()``
        num_molecules_to_visualize: how many items to draw; clamped to
            ``len(molecules)``
        log: wandb key used for the image; nothing is logged when it is None
            or when there is no active wandb run
        """
        # define path to save figures (exist_ok avoids the check-then-create race)
        os.makedirs(path, exist_ok=True)

        # visualize the final molecules
        print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
        if num_molecules_to_visualize > len(molecules):
            print(f"Shortening to {len(molecules)}")
            num_molecules_to_visualize = len(molecules)

        for i in range(num_molecules_to_visualize):
            file_path = os.path.join(path, 'molecule_{}.png'.format(i))
            mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
            if mol is None:
                # mol_from_graphs already printed why; there is nothing to draw
                continue
            try:
                Draw.MolToFile(mol, file_path)
            except rdkit.Chem.KekulizeException:
                print("Can't kekulize molecule")
                continue
            if wandb.run and log is not None:
                print(f"Saving {file_path} to wandb")
                wandb.log({log: wandb.Image(file_path)}, commit=True)

    def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
        """
        Draw every frame of a chain with a shared layout, then build a GIF and a grid image
        path: directory receiving the frame PNGs and the grid image; the GIF
            is written next to it, named after its last path component
        nodes_list: array whose first axis enumerates the frames
            (``nodes_list[i]`` is passed to ``mol_from_graphs``)
        adjacency_matrix: per-frame adjacency matrices, indexed like nodes_list
        trainer: unused, kept for interface compatibility
        Returns the list of per-frame molecules (empty for an empty chain).
        """
        num_frames = nodes_list.shape[0]
        if num_frames == 0:
            # no frame means no final molecule to align on and nothing to draw
            return []

        RDLogger.DisableLog('rdApp.*')
        # convert graphs to the rdkit molecules
        mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(num_frames)]

        # find the coordinates of atoms in the final molecule
        final_molecule = mols[-1]
        AllChem.Compute2DCoords(final_molecule)
        final_conf = final_molecule.GetConformer()
        coords = []
        for i in range(final_molecule.GetNumAtoms()):
            position = final_conf.GetAtomPosition(i)
            coords.append((position.x, position.y, position.z))

        # align all the molecules: atom j of every frame is pinned to the
        # position atom j has in the final molecule
        for mol in mols:
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            for j in range(mol.GetNumAtoms()):
                x, y, z = coords[j]
                conf.SetAtomPosition(j, Point3D(x, y, z))

        # draw gif
        save_paths = []
        for frame in range(num_frames):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}")
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        # basename equals split('/')[-1] on POSIX and also handles '\\' on Windows
        chain_name = os.path.basename(path)
        gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(chain_name))
        # repeat the last frame so the GIF lingers on the final molecule
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)

        if wandb.run:
            print(f"Saving {gif_path} to wandb")
            wandb.log({"chain": wandb.Video(gif_path, fps=5, format="gif")}, commit=True)

        # draw grid image
        try:
            img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
            img.save(os.path.join(path, '{}_grid_image.png'.format(chain_name)))
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
        return mols
class NonMolecularVisualization:
    """Draw generic graphs with networkx/matplotlib: single images and GIF chains."""

    def to_networkx(self, node_list, adjacency_matrix):
        """
        Convert one graph to a networkx graph
        node_list: node types of a single graph, indexed as a 1-D sequence;
            an entry equal to -1 gets no node of its own
        adjacency_matrix: square array for the same graph; every entry >= 1
            becomes an edge whose ``color`` is the entry as a float and whose
            ``weight`` is three times the entry
        """
        graph = nx.Graph()

        for i, node_type in enumerate(node_list):
            if node_type == -1:
                continue
            graph.add_node(i, number=i, symbol=node_type, color_val=node_type)

        rows, cols = np.where(adjacency_matrix >= 1)
        for src, dst in zip(rows.tolist(), cols.tolist()):
            edge_type = adjacency_matrix[src][dst]
            graph.add_edge(src, dst, color=float(edge_type), weight=3 * edge_type)

        return graph

    def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
        """
        Draw a graph to ``path``, colouring nodes by a Laplacian eigenvector
        graph: networkx graph to draw
        pos: node positions; a spring layout is computed when it is None
        path: file the figure is saved to
        iterations: iterations of the spring layout (only used if pos is None)
        node_size: forwarded to ``nx.draw``
        largest_component: if True, only the largest connected component is drawn
        """
        if largest_component and graph.number_of_nodes() > 0:
            # max() keeps the first of equally sized components, like a stable
            # descending sort followed by [0]
            biggest = max(nx.connected_components(graph), key=len)
            graph = graph.subgraph(biggest)

        # Plot the graph structure with colors
        if pos is None:
            pos = nx.spring_layout(graph, iterations=iterations)

        # Set node colors based on the eigenvectors
        if graph.number_of_nodes() >= 2:
            _, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
            node_colors = U[:, 1]
            # symmetric colour range so that 0 sits at the centre of the colormap
            m = max(np.abs(np.min(node_colors)), np.max(node_colors))
            vmin, vmax = -m, m
        else:
            # fewer than two nodes: column 1 of the eigenvectors does not
            # exist, so use a neutral colour instead of raising IndexError
            node_colors = np.zeros(graph.number_of_nodes())
            vmin, vmax = -1.0, 1.0

        plt.figure()
        nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=node_colors,
                cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')

        plt.tight_layout()
        plt.savefig(path)
        plt.close("all")

    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
        """
        Save the first graphs of a list as PNG files and optionally log them
        path: output directory, created if missing
        graphs: sequence whose items expose the node types at index 0 and the
            adjacency matrix at index 1, each providing ``.numpy()``
        num_graphs_to_visualize: how many items to draw; clamped to
            ``len(graphs)``
        log: wandb key used for the images; nothing is logged when it is None
            or when there is no active wandb run
        """
        # define path to save figures (exist_ok avoids the check-then-create race)
        os.makedirs(path, exist_ok=True)

        # never index past the end of the list
        num_graphs_to_visualize = min(num_graphs_to_visualize, len(graphs))

        # visualize the final graphs
        for i in range(num_graphs_to_visualize):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
            self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
            if wandb.run and log is not None:
                # the image is only read back when it is actually logged
                im = plt.imread(file_path)
                wandb.log({log: [wandb.Image(im, caption=file_path)]})

    def visualize_chain(self, path, nodes_list, adjacency_matrix):
        """
        Draw every frame of a chain with the final graph's layout and build a GIF
        path: directory receiving the frame PNGs; the GIF is written next to
            it, named after its last path component
        nodes_list: array whose first axis enumerates the frames
            (``nodes_list[i]`` is passed to ``to_networkx``)
        adjacency_matrix: per-frame adjacency matrices, indexed like nodes_list
        """
        num_frames = nodes_list.shape[0]
        if num_frames == 0:
            # no frame means no final graph to take the layout from
            return

        # convert graphs to networkx
        graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(num_frames)]

        # find the coordinates of the nodes in the final graph
        final_graph = graphs[-1]
        final_pos = nx.spring_layout(final_graph, seed=0)

        # draw gif
        save_paths = []
        for frame in range(num_frames):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name)
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        # basename equals split('/')[-1] on POSIX and also handles '\\' on Windows
        gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(os.path.basename(path)))
        # repeat the last frame so the GIF lingers on the final graph
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)

        if wandb.run:
            wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})