|
|
|
|
|
|
|
|
|
|
|
import os |
|
import csv |
|
import math |
|
import subsequence_model |
|
import torch |
|
import json |
|
import pickle |
|
import numpy as np |
|
from rdkit import Chem |
|
from Bio import SeqIO |
|
from collections import Counter |
|
from collections import defaultdict |
|
import matplotlib.pyplot as plt |
|
from matplotlib import rc |
|
from matplotlib.legend_handler import HandlerPathCollection |
|
from scipy import stats |
|
import seaborn as sns |
|
import pandas as pd |
|
from sklearn.metrics import mean_squared_error,r2_score |
|
|
|
|
|
# Lookup tables read from ../../Data/input at import time.  The helper
# functions below read them through these module-level names.
(
    fingerprint_dict,
    atom_dict,
    bond_dict,
    edge_dict,
    word_dict,
) = map(
    subsequence_model.load_pickle,
    (
        '../../Data/input/fingerprint_dict.pickle',
        '../../Data/input/atom_dict.pickle',
        '../../Data/input/bond_dict.pickle',
        '../../Data/input/edge_dict.pickle',
        '../../Data/input/sequence_dict.pickle',
    ),
)
|
|
|
def split_sequence(sequence, ngram):
    """Encode a protein sequence as an array of n-gram IDs.

    The sequence is padded with a leading '-' and a trailing '=' and cut into
    overlapping windows of length ``ngram``.  Every window is looked up in the
    module-level ``word_dict``; a window that is not present is registered
    there with ID 0 and encoded as 0.

    Args:
        sequence: Sequence string to encode.
        ngram: Window length.

    Returns:
        A NumPy array with one ID per window (empty when the padded sequence
        is shorter than ``ngram``).
    """
    padded = '-' + sequence + '='

    # setdefault reproduces the previous fallback (unknown window -> stored
    # as 0 and returned as 0) without a bare ``except`` that would also hide
    # unrelated errors.
    return np.array([
        word_dict.setdefault(padded[start:start + ngram], 0)
        for start in range(len(padded) - ngram + 1)
    ])
|
|
|
|
|
def create_atoms(mol):
    """Map every atom of ``mol`` to its ID in ``atom_dict``.

    Non-aromatic atoms are keyed by their element symbol; aromatic atoms are
    keyed by the tuple ``(symbol, 'aromatic')``.

    Returns:
        A NumPy array of atom IDs, one per atom of ``mol``.
    """
    aromatic_indices = {atom.GetIdx() for atom in mol.GetAromaticAtoms()}

    keys = []
    for atom in mol.GetAtoms():
        symbol = atom.GetSymbol()
        if atom.GetIdx() in aromatic_indices:
            keys.append((symbol, 'aromatic'))
        else:
            keys.append(symbol)

    return np.array([atom_dict[key] for key in keys])
|
|
|
def create_ijbonddict(mol):
    """Build the neighbour table of a molecule.

    Returns:
        A ``defaultdict`` mapping each atom index to a list of
        ``(neighbour_index, bond_id)`` tuples, where ``bond_id`` is the
        ``bond_dict`` entry for the string form of the bond type.  Every bond
        is recorded in both directions.
    """
    neighbours = defaultdict(list)
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_id = bond_dict[str(bond.GetBondType())]
        neighbours[begin].append((end, bond_id))
        neighbours[end].append((begin, bond_id))
    return neighbours
|
|
|
def extract_fingerprints(atoms, i_jbond_dict, radius):
    """Extract the r-radius subgraphs (i.e., fingerprints)
    from a molecular graph using Weisfeiler-Lehman algorithm.

    Args:
        atoms: Sequence of atom IDs, indexable by atom index.
        i_jbond_dict: Mapping ``atom index -> [(neighbour index, bond ID)]``.
        radius: Number of refinement rounds.

    Returns:
        A NumPy array of fingerprint IDs.  With a single atom or
        ``radius == 0`` every atom ID is looked up directly in
        ``fingerprint_dict`` (a missing ID raises ``KeyError``).  Otherwise
        unseen fingerprints and unseen edges are registered in the
        module-level ``fingerprint_dict`` / ``edge_dict`` with ID 0 and
        encoded as 0.

    NOTE(review): fingerprints are emitted in the iteration order of
    ``i_jbond_dict`` and are then indexed by atom index in the next round, so
    the dictionary keys are expected to appear as 0..n-1 in order -- confirm
    this holds for every input molecule.
    """
    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]

    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict

        for _ in range(radius):
            # Update each node ID considering its neighboring nodes and edges
            # (i.e., r-radius subgraphs or fingerprints).
            fingerprints = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                # Unseen fingerprints fall back to ID 0 and are remembered;
                # setdefault replaces the former bare ``except``.
                fingerprints.append(fingerprint_dict.setdefault(fingerprint, 0))
            nodes = fingerprints

            # Also update each edge ID considering two nodes on its both sides.
            _i_jedge_dict = defaultdict(list)
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict.setdefault((both_side, edge), 0)
                    _i_jedge_dict[i].append((j, edge))
            i_jedge_dict = _i_jedge_dict

    return np.array(fingerprints)
|
|
|
def create_adjacency(mol):
    """Return ``Chem.GetAdjacencyMatrix(mol)`` wrapped in a NumPy array."""
    return np.array(Chem.GetAdjacencyMatrix(mol))
|
|
|
def dump_dictionary(dictionary, filename):
    """Pickle ``dictionary``, converted to a plain ``dict``, into ``filename``."""
    with open(filename, 'wb') as sink:
        snapshot = dict(dictionary)
        pickle.dump(snapshot, sink)
|
|
|
def load_tensor(file_name, dtype, device=None):
    """Load ``<file_name>.npy`` and turn every item into a tensor.

    Args:
        file_name: Path of the array file without the ``.npy`` suffix.
        dtype: Callable that builds a tensor from one item
            (e.g. ``torch.LongTensor``).
        device: Target device.  When omitted, CUDA is used if available and
            the CPU otherwise.  This function previously read a module-level
            ``device`` that this file never defines, so every call failed
            with ``NameError``.

    Returns:
        A list with one tensor per item of the stored array.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE: allow_pickle=True unpickles object arrays; only load trusted files.
    stored = np.load(file_name + '.npy', allow_pickle=True)
    return [dtype(item).to(device) for item in stored]
|
|
|
|
|
|
|
# Marker size applied to every legend handle by update_prop().
marker_size = 25


def update_prop(handle, orig):
    """Copy the properties of ``orig`` onto ``handle``, then force its size.

    The size is always ``[marker_size]``, whatever sizes ``orig`` carried.
    """
    handle.update_from(orig)
    legend_sizes = [marker_size]
    handle.set_sizes(legend_sizes)
|
|
|
def plot_attention_weights(attention_profiles, wildtype_like_positions, wildtype_decreased_positions, wildtype_like, wildtype_decreased):
    """Plot an attention profile with two groups of positions overlaid.

    The profile is drawn as a dashed line against 1-based positions.  Each
    group is drawn as a scatter whose marker size is five times the number of
    occurrences of that position within the group.  The figure is written to
    ``../../Results/figures/Fig3e.pdf``.

    Args:
        attention_profiles: Iterable of values convertible with ``float``.
        wildtype_like_positions: x positions of the 'Wildtype_like' points.
        wildtype_decreased_positions: x positions of the
            'Wildtype_decreased' points.
        wildtype_like: y values matching ``wildtype_like_positions``.
        wildtype_decreased: y values matching
            ``wildtype_decreased_positions``.
    """
    weights = [float(attention) for attention in attention_profiles]
    positions = list(range(1, len(weights) + 1))

    plt.figure(figsize=(2.0, 1.5))

    rc('font', **{'family': 'serif', 'serif': ['Helvetica']})
    # Type 42 (TrueType) fonts in the PDF output.
    plt.rcParams['pdf.fonttype'] = 42

    plt.axes([0.12, 0.12, 0.83, 0.83])

    plt.tick_params(direction='in')
    plt.tick_params(which='major', length=1.5)
    plt.tick_params(which='major', width=0.4)

    plt.plot(positions, weights, color='k', linestyle='--', linewidth=0.75)

    # Count each position once; the counts were previously rebuilt for every
    # single point inside the size comprehensions.
    like_counts = Counter(wildtype_like_positions)
    decreased_counts = Counter(wildtype_decreased_positions)
    print(like_counts)
    print(decreased_counts)

    plt.scatter(
        wildtype_like_positions,
        wildtype_like,
        s=[like_counts[position] * 5 for position in wildtype_like_positions],
        color='#2166ac',
        marker='o',
        label='Wildtype_like',
    )
    sc = plt.scatter(
        wildtype_decreased_positions,
        wildtype_decreased,
        s=[decreased_counts[position] * 5 for position in wildtype_decreased_positions],
        color='#b2182b',
        marker='o',
        label='Wildtype_decreased',
    )

    plt.rcParams['font.family'] = 'Helvetica'
    plt.xlabel('Residue position', fontsize=7)
    plt.ylabel('Attention weight', fontsize=7)

    plt.xticks([0, 50, 100, 150, 200, 250, 300])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4])

    plt.xticks(fontsize=6)
    plt.yticks(fontsize=6)

    ax = plt.gca()
    for side in ('bottom', 'left', 'top', 'right'):
        ax.spines[side].set_linewidth(0.5)

    # update_prop gives every legend marker the same fixed size.
    plt.legend(handler_map={type(sc): HandlerPathCollection(update_func=update_prop)}, frameon=False, markerscale=2.0, numpoints=1, prop={"size": 6})

    plt.savefig("../../Results/figures/Fig3e.pdf", dpi=400, bbox_inches='tight')
|
|
|
class Predictor(object):
    """Thin wrapper that exposes a model's forward pass as ``predict``."""

    def __init__(self, model):
        self.model = model

    def predict(self, data):
        """Run ``self.model.forward(data)`` and return its two outputs.

        Returns:
            The tuple ``(predicted_value, attention_profiles)``.
        """
        outputs = self.model.forward(data)
        predicted_value, attention_profiles = outputs
        return (predicted_value, attention_profiles)
|
|
|
def extract_wildtype_mutant(database_path='../../Data/database/Kcat_combination_0918_wildtype_mutant.json'):
    """List the most frequent substrate/organism/EC combinations.

    Every record of the JSON database is reduced to the key
    ``Substrate + '&' + Organism + '&' + ECNumber``.

    Args:
        database_path: JSON file holding a list of records; defaults to the
            project database.

    Returns:
        Up to 30 ``(key, count)`` tuples for keys that occur more than once,
        ordered by descending count.  Keys with equal counts keep their order
        of first appearance in the file.
    """
    with open(database_path, 'r') as infile:
        Kcat_data = json.load(infile)

    entry_counts = Counter(
        data['Substrate'] + '&' + data['Organism'] + '&' + data['ECNumber']
        for data in Kcat_data
    )
    duplicated = [(key, count) for key, count in entry_counts.items() if count > 1]

    # sorted() is stable, also with reverse=True, so ties stay in file order.
    return sorted(duplicated, key=lambda item: item[1], reverse=True)[:30]
|
|
|
def compare_list(mutant, wildtype):
    """Return the items of ``mutant`` that differ from ``wildtype``.

    Only indices ``0 .. len(wildtype) - 1`` are compared, so ``mutant`` must
    be at least as long as ``wildtype``.
    """
    return [
        mutant[index]
        for index in range(len(wildtype))
        if mutant[index] != wildtype[index]
    ]
|
|
|
def compare_mutant_wildtype_sequence(mutant, wildtype):
    """Return the 0-based indices at which ``mutant`` differs from ``wildtype``.

    Only indices ``0 .. len(wildtype) - 1`` are compared, so ``mutant`` must
    be at least as long as ``wildtype``.
    """
    return [
        index
        for index in range(len(wildtype))
        if mutant[index] != wildtype[index]
    ]
|
|
|
def extract_wildtype_kcat(entry, database_path='../../Data/database/Kcat_combination_0918_wildtype_mutant.json'):
    """Look up the wildtype value recorded for ``entry``.

    Args:
        entry: Key of the form ``Substrate&Organism&ECNumber``.
        database_path: JSON file holding a list of records; defaults to the
            project database.

    Returns:
        ``float(record['Value'])`` of the last record whose key equals
        ``entry`` and whose ``Type`` is ``'wildtype'``.  ``None`` is returned
        when that value is zero or when no such record exists (the latter
        used to raise ``UnboundLocalError``).
    """
    with open(database_path, 'r') as infile:
        Kcat_data = json.load(infile)

    wildtype_kcat = None
    for data in Kcat_data:
        one_entry = data['Substrate'] + '&' + data['Organism'] + '&' + data['ECNumber']
        if one_entry == entry and data['Type'] == 'wildtype':
            # No break on purpose: a later wildtype record overrides an
            # earlier one, as before.
            wildtype_kcat = float(data['Value'])

    return wildtype_kcat if wildtype_kcat else None
|
|
|
def extract_wildtype_sequence(entry, database_path='../../Data/database/Kcat_combination_0918_wildtype_mutant.json'):
    """Look up the wildtype sequence recorded for ``entry``.

    Args:
        entry: Key of the form ``Substrate&Organism&ECNumber``.
        database_path: JSON file holding a list of records; defaults to the
            project database.

    Returns:
        The ``Sequence`` field of the last record whose key equals ``entry``
        and whose ``Type`` is ``'wildtype'``.  ``None`` is returned when that
        field is empty or when no such record exists (the latter used to
        raise ``UnboundLocalError``).
    """
    with open(database_path, 'r') as infile:
        Kcat_data = json.load(infile)

    wildtype_sequence = None
    for data in Kcat_data:
        one_entry = data['Substrate'] + '&' + data['Organism'] + '&' + data['ECNumber']
        if one_entry == entry and data['Type'] == 'wildtype':
            # No break on purpose: a later wildtype record overrides an
            # earlier one, as before.
            wildtype_sequence = data['Sequence']

    return wildtype_sequence if wildtype_sequence else None
|
|
|
def extract_wildtype_attention(wildtype_entry):
    """Run the trained model on the wildtype record(s) of ``wildtype_entry``.

    The JSON database is scanned for records whose key
    (``Substrate&Organism&ECNumber``) equals ``wildtype_entry`` and whose
    ``Type`` is ``'wildtype'``.  Records whose SMILES contains ``'.'`` or
    whose value is not positive are skipped.  Each remaining record is
    featurised and passed to the model; the result of the last one is kept.

    Args:
        wildtype_entry: Key of the form ``Substrate&Organism&ECNumber``.

    Returns:
        ``(attention_profiles, sequence)`` where ``attention_profiles`` is the
        second value returned by the model's forward pass and ``sequence`` is
        the sequence of the record that produced it.  (Previously the
        sequence could come from a later wildtype record that had been
        skipped by the SMILES/value filter.)

    Raises:
        ValueError: If no usable wildtype record exists for the entry.  This
            situation used to surface as ``UnboundLocalError``.
    """
    with open('../../Data/database/Kcat_combination_0918_wildtype_mutant.json', 'r') as infile:
        Kcat_data = json.load(infile)

    # The vocabulary sizes must come from freshly loaded dictionaries: the
    # module-level copies gain keys whenever an unseen fingerprint or n-gram
    # is encoded, which would change their length.
    fingerprint_dict = subsequence_model.load_pickle('../../Data/input/fingerprint_dict.pickle')
    word_dict = subsequence_model.load_pickle('../../Data/input/sequence_dict.pickle')
    n_fingerprint = len(fingerprint_dict)
    n_word = len(word_dict)

    # Hyper-parameters passed to the model constructor and featurisers.
    radius = 2
    ngram = 3
    dim = 10
    layer_gnn = 3
    window = 11
    layer_cnn = 3
    layer_output = 3

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    Kcat_model = subsequence_model.KcatPrediction(device, n_fingerprint, n_word, 2*dim, layer_gnn, window, layer_cnn, layer_output).to(device)
    Kcat_model.load_state_dict(torch.load('../../Results/output/all--radius2--ngram3--dim20--layer_gnn3--window11--layer_cnn3--layer_output3--lr1e-3--lr_decay0.5--decay_interval10--weight_decay1e-6--iteration50', map_location=device))

    predictor = Predictor(Kcat_model)

    wildtype_attention_profiles = None
    wildtype_sequence = None

    for data in Kcat_data:
        entry = data['Substrate'] + '&' + data['Organism'] + '&' + data['ECNumber']
        if entry != wildtype_entry or data['Type'] != 'wildtype':
            continue

        smiles = data['Smiles']
        sequence = data['Sequence']
        Kcat = data['Value']
        if not ("." not in smiles and float(Kcat) > 0):
            continue

        mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
        atoms = create_atoms(mol)
        i_jbond_dict = create_ijbonddict(mol)

        fingerprints = torch.LongTensor(extract_fingerprints(atoms, i_jbond_dict, radius))
        adjacency = torch.FloatTensor(create_adjacency(mol))
        words = torch.LongTensor(split_sequence(sequence, ngram))

        inputs = [fingerprints, adjacency, words]
        _, wildtype_attention_profiles = predictor.predict(inputs)
        wildtype_sequence = sequence

    if wildtype_attention_profiles is None:
        raise ValueError('No usable wildtype record found for entry: ' + str(wildtype_entry))

    return wildtype_attention_profiles, wildtype_sequence
|
|
|
def output_wildtype_enzyme(wildtype_attention_profiles, sequence, output_path='../../Results/output/supple_wildtype_PNP_attention_weights.tsv'):
    """Print a sequence with its attention weights and save them as TSV.

    Args:
        wildtype_attention_profiles: Sized iterable with one weight per
            sequence position.  Items that are not already strings are
            written as ``str(float(item))``.
        sequence: Sequence string; must be at least as long as the profile.
        output_path: Destination file; defaults to the project output file.

    The file has the header
    ``Sequence position / Amino acid / Attention weight`` followed by one
    tab-separated row per weight, with 1-based positions.
    """
    sequence_length = len(sequence)
    attention_weights_length = len(wildtype_attention_profiles)

    print('The length of wildtype enzyme:', sequence_length)
    print('The length of attention weights:', attention_weights_length)
    print(sequence)
    print(wildtype_attention_profiles)

    with open(output_path, 'w') as outfile:
        items = ['Sequence position', 'Amino acid', 'Attention weight']
        outfile.write('\t'.join(items) + '\n')
        for position, attention in enumerate(wildtype_attention_profiles, start=1):
            # str.join only accepts strings; the raw weight used to be passed
            # through unchanged, which raised TypeError for numeric weights.
            weight = attention if isinstance(attention, str) else str(float(attention))
            line = [str(position), sequence[position - 1], weight]
            outfile.write('\t'.join(line) + '\n')
|
|
|
def wildtype_like_decreased_info():
    """Collect wildtype attention weights at the mutated positions of variants.

    The first entry returned by ``extract_wildtype_mutant()`` whose substrate
    is ``'Inosine'`` is processed.  For every database record of that entry
    whose SMILES has no ``'.'`` and whose value is positive, the value is
    divided by the wildtype value and the record's sequence is compared with
    the wildtype sequence.  Each differing position is stored, 1-based, with
    the wildtype attention weight at that position:

    * ratio in ``[0.5, 2.0)`` -> "wildtype-like" lists,
    * ratio ``< 0.5``         -> "wildtype-decreased" lists,
    * ratio ``>= 2.0``        -> ignored.

    Changes from the previous implementation:

    * The wildtype value, sequence and attention profile are fetched once
      per entry instead of once per matching record (each fetch re-read the
      database and reloaded the model).
    * The model load and per-record featurisation/prediction were removed
      because their results were never read.
    * Four (possibly empty) lists are always returned.

    NOTE(review): the indentation of the original was lost, so the placement
    of the final ``return`` was inferred.  This version returns the results
    of the first 'Inosine' entry -- confirm against the intended behaviour.

    Returns:
        ``(wildtype_like_positions, wildtype_decreased_positions,
        wildtype_like, wildtype_decreased)``.

    Raises:
        ValueError: If a record has to be normalised but the entry has no
            non-zero wildtype value.
    """
    with open('../../Data/database/Kcat_combination_0918_wildtype_mutant.json', 'r') as infile:
        Kcat_data = json.load(infile)

    wildtype_mutant_entries = extract_wildtype_mutant()

    print('It\'s time to start the prediction!')
    print('-----------------------------------')

    wildtype_like = list()
    wildtype_decreased = list()
    wildtype_like_positions = list()
    wildtype_decreased_positions = list()

    for wildtype_mutant_entry in wildtype_mutant_entries:
        entry_key = wildtype_mutant_entry[0]
        entry_names = entry_key.split('&')
        if entry_names[0] != 'Inosine':
            continue

        print('This entry is:', entry_names)

        # Loop-invariant wildtype reference data for this entry.
        wildtype_kcat = extract_wildtype_kcat(entry_key)
        wildtype_sequence = extract_wildtype_sequence(entry_key)
        wildtype_attention_profiles = extract_wildtype_attention(entry_key)[0]

        for data in Kcat_data:
            entry = data['Substrate'] + '&' + data['Organism'] + '&' + data['ECNumber']
            if entry != entry_key:
                continue

            smiles = data['Smiles']
            sequence = data['Sequence']
            Kcat = data['Value']
            if not ("." not in smiles and float(Kcat) > 0):
                continue

            if not wildtype_kcat:
                raise ValueError('No non-zero wildtype value for entry: ' + entry_key)

            normalized_value = float(Kcat) / wildtype_kcat
            different_positions = compare_mutant_wildtype_sequence(sequence, wildtype_sequence)

            if 0.5 <= normalized_value < 2.0:
                target_positions, target_weights = wildtype_like_positions, wildtype_like
            elif normalized_value < 0.5:
                target_positions, target_weights = wildtype_decreased_positions, wildtype_decreased
            else:
                # Variants at least twice as active as the wildtype are not plotted.
                continue

            for position in different_positions:
                target_positions.append(position + 1)
                target_weights.append(float(wildtype_attention_profiles[position]))

        # Only the first matching entry is reported.
        break

    print('wildtype_like_positions:', wildtype_like_positions)
    print('wildtype_decreased_positions:', wildtype_decreased_positions)

    print('Attention weights in wildtype_like:', wildtype_like)
    print('Attention weights in wildtype_decreased:', wildtype_decreased)

    return wildtype_like_positions, wildtype_decreased_positions, wildtype_like, wildtype_decreased
|
|
|
def main():
    """Produce the attention figure for Inosine / Homo sapiens / EC 2.4.2.1."""
    entry = '&'.join(('Inosine', 'Homo sapiens', '2.4.2.1'))

    wildtype_attentions, _sequence = extract_wildtype_attention(entry)

    mutation_info = wildtype_like_decreased_info()
    like_positions, decreased_positions, like_weights, decreased_weights = mutation_info

    plot_attention_weights(
        wildtype_attentions,
        like_positions,
        decreased_positions,
        like_weights,
        decreased_weights,
    )


if __name__ == '__main__':
    main()
|
|