import os
import math
import json
import pickle
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib import rc
import seaborn as sns
from scipy import stats
from scipy.stats import ranksums
from sklearn.metrics import mean_squared_error, r2_score
from rdkit import Chem
from Bio import SeqIO

import model

fingerprint_dict = model.load_pickle('../../Data/input/fingerprint_dict.pickle')
atom_dict = model.load_pickle('../../Data/input/atom_dict.pickle')
bond_dict = model.load_pickle('../../Data/input/bond_dict.pickle')
edge_dict = model.load_pickle('../../Data/input/edge_dict.pickle')
word_dict = model.load_pickle('../../Data/input/sequence_dict.pickle')


def split_sequence(sequence, ngram):
    sequence = '-' + sequence + '='

    words = list()
    for i in range(len(sequence)-ngram+1):
        try:
            words.append(word_dict[sequence[i:i+ngram]])
        except KeyError:
            # n-grams unseen during training are mapped to index 0
            word_dict[sequence[i:i+ngram]] = 0
            words.append(word_dict[sequence[i:i+ngram]])

    return np.array(words)
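
# For illustration (indices depend on the trained word_dict): split_sequence('MSK', 3)
# pads the sequence to '-MSK=' and returns the indices of its 3-grams '-MS', 'MSK', 'SK='.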


def create_atoms(mol):
    """Create a list of atom (e.g., hydrogen and oxygen) IDs
    considering the aromaticity."""

    atoms = [a.GetSymbol() for a in mol.GetAtoms()]

    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        atoms[i] = (atoms[i], 'aromatic')
    atoms = [atom_dict[a] for a in atoms]

    return np.array(atoms)
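
# For illustration: an aromatic carbon is looked up as ('C', 'aromatic') in atom_dict,
# so aromatic and aliphatic atoms of the same element receive different IDs.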


def create_ijbonddict(mol):
    """Create a dictionary in which each key is a node ID
    and each value is a list of tuples of its neighboring node
    and bond (e.g., single and double) IDs."""

    i_jbond_dict = defaultdict(lambda: [])
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    return i_jbond_dict
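
# For illustration (hypothetical IDs): for a single C-O bond between atoms 0 and 1,
# i_jbond_dict[0] == [(1, bond_dict['SINGLE'])] and i_jbond_dict[1] == [(0, bond_dict['SINGLE'])].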


def extract_fingerprints(atoms, i_jbond_dict, radius):
    """Extract the r-radius subgraphs (i.e., fingerprints)
    from a molecular graph using the Weisfeiler-Lehman algorithm."""

    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]

    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict

        for _ in range(radius):

            """Update each node ID considering its neighboring nodes and edges
            (i.e., r-radius subgraphs or fingerprints)."""
            fingerprints = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                try:
                    fingerprints.append(fingerprint_dict[fingerprint])
                except KeyError:
                    # fingerprints unseen during training are mapped to index 0
                    fingerprint_dict[fingerprint] = 0
                    fingerprints.append(fingerprint_dict[fingerprint])

            nodes = fingerprints

            """Also update each edge ID considering the two nodes
            on both of its sides."""
            _i_jedge_dict = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    try:
                        edge = edge_dict[(both_side, edge)]
                    except KeyError:
                        edge_dict[(both_side, edge)] = 0
                        edge = edge_dict[(both_side, edge)]

                    _i_jedge_dict[i].append((j, edge))
            i_jedge_dict = _i_jedge_dict

    return np.array(fingerprints)
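
# For illustration: with radius=2 each atom ID is replaced twice by the fingerprint_dict
# index of (its current ID, the sorted (neighbor ID, bond ID) pairs), so the final
# fingerprint of an atom encodes its two-bond neighborhood.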


def create_adjacency(mol):
    adjacency = Chem.GetAdjacencyMatrix(mol)
    return np.array(adjacency)


def dump_dictionary(dictionary, filename):
    with open(filename, 'wb') as file:
        pickle.dump(dict(dictionary), file)


def load_tensor(file_name, dtype):
    # Expects a global `device` to be defined before this is called.
    return [dtype(d).to(device) for d in np.load(file_name + '.npy', allow_pickle=True)]


class Predictor(object):
    def __init__(self, model):
        self.model = model

    def predict(self, data):
        predicted_value = self.model.forward(data)

        return predicted_value
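
# Usage sketch (tensors prepared as in compare_prediction_wildtype_mutant below):
#     predictor = Predictor(Kcat_model)
#     log_kcat = predictor.predict([fingerprints, adjacency, words]).item()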


def extract_wildtype_mutant():
    with open('../../Data/database/Kcat_combination_0918_wildtype_mutant.json', 'r') as infile:
        Kcat_data = json.load(infile)

    entry_keys = list()
    for data in Kcat_data:
        substrate = data['Substrate']
        organism = data['Organism']
        EC = data['ECNumber']
        entry_key = substrate + '&' + organism + '&' + EC

        entry_keys.append(entry_key)

    entry_dict = dict(Counter(entry_keys))

    duplicated_keys = [key for key, value in entry_dict.items() if value > 1]

    duplicated_dict = {key: value for key, value in entry_dict.items() if value > 1}

    # The 30 entry keys with the most measurements (wildtype plus mutants)
    duplicated_list = sorted(duplicated_dict.items(), key=lambda x: x[1], reverse=True)[:30]

    # Inspect the most frequent entry; the type and value are not used further here
    for duplicated in duplicated_list[:1]:
        for data in Kcat_data:
            substrate = data['Substrate']
            organism = data['Organism']
            EC = data['ECNumber']
            one_entry = substrate + '&' + organism + '&' + EC
            if one_entry == duplicated[0]:
                enzyme_type = data['Type']
                Kcat_value = data['Value']

    return duplicated_list
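
# Entry keys concatenate substrate, organism and EC number with '&',
# e.g. 'Inosine&Homo sapiens&2.4.2.1' (format shown for illustration only);
# the same keys are rebuilt in compare_prediction_wildtype_mutant() below.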


def extract_wildtype_kcat(entry):
    with open('../../Data/database/Kcat_combination_0918_wildtype_mutant.json', 'r') as infile:
        Kcat_data = json.load(infile)

    wildtype_kcat = None
    for data in Kcat_data:
        substrate = data['Substrate']
        organism = data['Organism']
        EC = data['ECNumber']
        one_entry = substrate + '&' + organism + '&' + EC
        if one_entry == entry:
            enzyme_type = data['Type']
            if enzyme_type == 'wildtype':
                wildtype_kcat = float(data['Value'])

    # None when the entry has no wildtype measurement
    return wildtype_kcat


def compare_prediction_wildtype_mutant():
    with open('../../Data/database/Kcat_combination_0918_wildtype_mutant.json', 'r') as infile:
        Kcat_data = json.load(infile)

    wildtype_mutant_entries = extract_wildtype_mutant()

    fingerprint_dict = model.load_pickle('../../Data/input/fingerprint_dict.pickle')
    atom_dict = model.load_pickle('../../Data/input/atom_dict.pickle')
    bond_dict = model.load_pickle('../../Data/input/bond_dict.pickle')
    word_dict = model.load_pickle('../../Data/input/sequence_dict.pickle')
    n_fingerprint = len(fingerprint_dict)
    n_word = len(word_dict)

    # Architecture settings of the trained model (the network below is built with
    # 2*dim = 20, matching 'dim20' in the checkpoint name); the optimiser settings
    # are kept for reference and are not used at prediction time.
    radius = 2
    ngram = 3

    dim = 10
    layer_gnn = 3
    side = 5
    window = 11
    layer_cnn = 3
    layer_output = 3
    lr = 1e-3
    lr_decay = 0.5
    decay_interval = 10
    weight_decay = 1e-6
    iteration = 100

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    Kcat_model = model.KcatPrediction(device, n_fingerprint, n_word, 2*dim, layer_gnn, window, layer_cnn, layer_output).to(device)
    Kcat_model.load_state_dict(torch.load('../../Results/output/all--radius2--ngram3--dim20--layer_gnn3--window11--layer_cnn3--layer_output3--lr1e-3--lr_decay0.5--decay_interval10--weight_decay1e-6--iteration50', map_location=device))

    predictor = Predictor(Kcat_model)

    print('It\'s time to start the prediction!')
    print('-----------------------------------')

    i = 0
    alldata = dict()
    alldata['substrate'] = list()
    alldata['experimental'] = list()
    alldata['predicted'] = list()

    experimental_values = list()
    predicted_values = list()

    # Substrates included in the figure and the short names of their enzymes
    substrate_enzymes = {
        '7,8-Dihydrofolate': 'DHFR',
        'Glycerate 3-phosphate': 'PGDH',
        'L-Aspartate': 'AKIII',
        'Penicillin G': 'DAOCS',
        'Inosine': 'PNP',
        'Isopentenyl diphosphate': 'GGPPs'
    }

    for wildtype_mutant_entry in wildtype_mutant_entries:
        entry_names = wildtype_mutant_entry[0].split('&')

        if entry_names[0] in substrate_enzymes:
            print('This entry is:', entry_names)
            for data in Kcat_data:
                substrate = data['Substrate']
                organism = data['Organism']
                EC = data['ECNumber']
                entry = substrate + '&' + organism + '&' + EC

                if entry == wildtype_mutant_entry[0]:
                    substrate_name = entry_names[0]

                    alldata['substrate'].append(substrate_enzymes[substrate_name] + ' & ' + substrate_name)
                    wildtype_kcat = extract_wildtype_kcat(entry)

                    i += 1

                    smiles = data['Smiles']
                    sequence = data['Sequence']
                    enzyme_type = data['Type']
                    Kcat = data['Value']
                    if "." not in smiles and float(Kcat) > 0:
                        mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
                        atoms = create_atoms(mol)
                        i_jbond_dict = create_ijbonddict(mol)
                        fingerprints = extract_fingerprints(atoms, i_jbond_dict, radius)
                        adjacency = create_adjacency(mol)
                        words = split_sequence(sequence, ngram)

                        fingerprints = torch.LongTensor(fingerprints)
                        adjacency = torch.FloatTensor(adjacency)
                        words = torch.LongTensor(words)

                        inputs = [fingerprints, adjacency, words]

                        value = float(data['Value'])

                        # kcat relative to the wildtype value (kept for reference; not used below)
                        normalized_value = value/wildtype_kcat

                        experimental_values.append(math.log10(value))
                        alldata['experimental'].append(math.log10(value))

                        prediction = predictor.predict(inputs)
                        Kcat_log_value = prediction.item()
                        # The model output is treated as log2(kcat); convert back with
                        # pow(2, .) and take log10 for plotting.
                        Kcat_value = math.pow(2, Kcat_log_value)

                        predicted_values.append(math.log10(Kcat_value))
                        alldata['predicted'].append(math.log10(Kcat_value))
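
                        # For illustration: a network output of 3.0 corresponds to a
                        # kcat of 2**3.0 = 8, which is plotted as log10(8) ≈ 0.90.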

    correlation, p_value = stats.pearsonr(experimental_values, predicted_values)
    r2 = r2_score(experimental_values, predicted_values)
    rmse = np.sqrt(mean_squared_error(experimental_values, predicted_values))

    print('The overall r is %.4f' % correlation)
    print('The overall P value is', p_value)
    print('The overall R2 is %.4f' % r2)
    print('The overall RMSE is %.4f' % rmse)

    allData = pd.DataFrame(alldata)

    plt.figure(figsize=(1.5, 1.5))

    rc('font', **{'family': 'serif', 'serif': ['Helvetica']})
    plt.rcParams['pdf.fonttype'] = 42

    plt.axes([0.12, 0.12, 0.83, 0.83])

    plt.tick_params(direction='in')
    plt.tick_params(which='major', length=1.5)
    plt.tick_params(which='major', width=0.4)

    # One color per enzyme & substrate pair
    palette = ("#FF8C00", "#A034F0", "#159090", "#1051D6", '#0AB944', '#DF16B7')

    scatter = sns.scatterplot(data=allData, x='experimental', y='predicted', hue='substrate',
                              palette=palette, ec='white', s=8, alpha=.7)

    scatter.get_legend().remove()

    plt.rcParams['font.family'] = 'Helvetica'

    scatter.set_xlabel(r"Experimental $k$$_\mathregular{cat}$ value", fontdict={'weight': 'normal', 'fontname': 'Helvetica', 'size': 7}, fontsize=7)
    scatter.set_ylabel(r'Predicted $k$$_\mathregular{cat}$ value', fontdict={'weight': 'normal', 'fontname': 'Helvetica', 'size': 7}, fontsize=7)

    plt.xticks([-5, -3, -1, 1, 3])
    plt.yticks([-5, -3, -1, 1, 3])
    plt.xticks(fontsize=6)
    plt.yticks(fontsize=6)

    # Identity line (perfect prediction)
    plt.plot([-5, -3, -1, 1, 3], [-5, -3, -1, 1, 3], color='b', linestyle='dashed', linewidth=1)

    ax = plt.gca()
    ax.spines['bottom'].set_linewidth(0.5)
    ax.spines['left'].set_linewidth(0.5)
    ax.spines['top'].set_linewidth(0.5)
    ax.spines['right'].set_linewidth(0.5)

    plt.legend(bbox_to_anchor=(1.01, 1), frameon=False, fontsize=6)
    plt.tight_layout()

    plt.savefig("../../Results/figures/Fig3c.pdf", dpi=400, bbox_inches='tight')


if __name__ == '__main__':
    compare_prediction_wildtype_mutant()