| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import gradio as gr | 
					
					
						
						| 
							 | 
						import py3Dmol | 
					
					
						
						| 
							 | 
						from Bio.PDB import * | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						from Bio.PDB import PDBParser | 
					
					
						
						| 
							 | 
						import pandas as pd | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						from MDmodel import GNN_MD | 
					
					
						
						| 
							 | 
						import h5py | 
					
					
						
						| 
							 | 
						from transformMD import GNNTransformMD | 
					
					
						
						| 
							 | 
						import sys | 
					
					
						
						| 
							 | 
						import pytraj as pt | 
					
					
						
						| 
							 | 
						import pickle  | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						resid_hover = """function(atom,viewer) {{ | 
					
					
						
						| 
							 | 
						    if(!atom.label) {{ | 
					
					
						
						| 
							 | 
						        atom.label = viewer.addLabel('{0}:'+atom.atom+atom.serial, | 
					
					
						
						| 
							 | 
						            {{position: atom, backgroundColor: 'mintcream', fontColor:'black'}}); | 
					
					
						
						| 
							 | 
						    }} | 
					
					
						
						| 
							 | 
						}}""" | 
					
					
						
						| 
							 | 
						hover_func = """ | 
					
					
						
						| 
							 | 
						function(atom,viewer) { | 
					
					
						
						| 
							 | 
						    if(!atom.label) { | 
					
					
						
						| 
							 | 
						        atom.label = viewer.addLabel(atom.interaction, | 
					
					
						
						| 
							 | 
						            {position: atom, backgroundColor: 'black', fontColor:'white'}); | 
					
					
						
						| 
							 | 
						    } | 
					
					
						
						| 
							 | 
						}""" | 
					
					
						
						| 
							 | 
						unhover_func = """ | 
					
					
						
						| 
							 | 
						function(atom,viewer) { | 
					
					
						
						| 
							 | 
						    if(atom.label) { | 
					
					
						
						| 
							 | 
						        viewer.removeLabel(atom.label); | 
					
					
						
						| 
							 | 
						        delete atom.label; | 
					
					
						
						| 
							 | 
						    } | 
					
					
						
						| 
							 | 
						}""" | 
					
					
						
						| 
							 | 
						atom_mapping = {0:'H', 1:'C', 2:'N', 3:'O', 4:'F', 5:'P', 6:'S', 7:'CL', 8:'BR', 9:'I', 10: 'UNK'} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						model = GNN_MD(11, 64) | 
					
					
						
						| 
							 | 
						state_dict = torch.load( | 
					
					
						
						| 
							 | 
						    "best_weights_rep0.pt", | 
					
					
						
						| 
							 | 
						    map_location=torch.device("cpu"), | 
					
					
						
						| 
							 | 
						)["model_state_dict"] | 
					
					
						
						| 
							 | 
						model.load_state_dict(state_dict) | 
					
					
						
						| 
							 | 
						model = model.to('cpu') | 
					
					
						
						| 
							 | 
						model.eval() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def run_leap(fileName, path): | 
					
					
						
						| 
							 | 
						    leapText = """ | 
					
					
						
						| 
							 | 
						    source leaprc.protein.ff14SB | 
					
					
						
						| 
							 | 
						    source leaprc.water.tip3p | 
					
					
						
						| 
							 | 
						    exp = loadpdb PATH4amb.pdb | 
					
					
						
						| 
							 | 
						    saveamberparm exp PATHexp.top PATHexp.crd | 
					
					
						
						| 
							 | 
						    quit | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    with open(path+"leap.in", "w") as outLeap: | 
					
					
						
						| 
							 | 
						        outLeap.write(leapText.replace('PATH', path))	 | 
					
					
						
						| 
							 | 
						    os.system("tleap -f "+path+"leap.in >> "+path+"leap.out") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def convert_to_amber_format(pdbName): | 
					
					
						
						| 
							 | 
						    fileName, path = pdbName+'.pdb', '' | 
					
					
						
						| 
							 | 
						    os.system("pdb4amber -i "+fileName+" -p -y -o "+path+"4amb.pdb -l "+path+"pdb4amber_protein.log") | 
					
					
						
						| 
							 | 
						    run_leap(fileName, path) | 
					
					
						
						| 
							 | 
						    traj = pt.iterload(path+'exp.crd', top = path+'exp.top') | 
					
					
						
						| 
							 | 
						    pt.write_traj(path+fileName, traj, overwrite= True) | 
					
					
						
						| 
							 | 
						    print(path+fileName+' was created. Please always use this file for inspection because the coordinates might get translated during amber file generation and thus might vary from the input pdb file.') | 
					
					
						
						| 
							 | 
						    return pt.iterload(path+'exp.crd', top = path+'exp.top') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_maps(mapPath): | 
					
					
						
						| 
							 | 
						    residueMap = pickle.load(open(mapPath+'atoms_residue_map_generate.pickle','rb')) | 
					
					
						
						| 
							 | 
						    nameMap = pickle.load(open(mapPath+'atoms_name_map_generate.pickle','rb')) | 
					
					
						
						| 
							 | 
						    typeMap = pickle.load(open(mapPath+'atoms_type_map_generate.pickle','rb')) | 
					
					
						
						| 
							 | 
						    elementMap = pickle.load(open(mapPath+'map_atomType_element_numbers.pickle','rb')) | 
					
					
						
						| 
							 | 
						    return residueMap, nameMap, typeMap, elementMap | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_residues_atomwise(residues): | 
					
					
						
						| 
							 | 
						    atomwise = [] | 
					
					
						
						| 
							 | 
						    for name, nAtoms in residues: | 
					
					
						
						| 
							 | 
						        for i in range(nAtoms): | 
					
					
						
						| 
							 | 
						            atomwise.append(name) | 
					
					
						
						| 
							 | 
						    return atomwise | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_begin_atom_index(traj): | 
					
					
						
						| 
							 | 
						    natoms = [m.n_atoms for m in traj.top.mols] | 
					
					
						
						| 
							 | 
						    molecule_begin_atom_index = [0]  | 
					
					
						
						| 
							 | 
						    x = 0 | 
					
					
						
						| 
							 | 
						    for i in range(len(natoms)): | 
					
					
						
						| 
							 | 
						        x += natoms[i] | 
					
					
						
						| 
							 | 
						        molecule_begin_atom_index.append(x) | 
					
					
						
						| 
							 | 
						    print('molecule begin atom index', molecule_begin_atom_index, natoms) | 
					
					
						
						| 
							 | 
						    return molecule_begin_atom_index | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_traj_info(traj, mapPath): | 
					
					
						
						| 
							 | 
						    coordinates  = traj.xyz | 
					
					
						
						| 
							 | 
						    residueMap, nameMap, typeMap, elementMap = get_maps(mapPath) | 
					
					
						
						| 
							 | 
						    types = [typeMap[a.type] for a in traj.top.atoms] | 
					
					
						
						| 
							 | 
						    elements = [elementMap[typ] for typ in types] | 
					
					
						
						| 
							 | 
						    atomic_numbers = [a.atomic_number for a in traj.top.atoms] | 
					
					
						
						| 
							 | 
						    molecule_begin_atom_index = get_begin_atom_index(traj) | 
					
					
						
						| 
							 | 
						    residues = [(residueMap[res.name], res.n_atoms) for res in traj.top.residues] | 
					
					
						
						| 
							 | 
						    residues_atomwise = get_residues_atomwise(residues) | 
					
					
						
						| 
							 | 
						    return coordinates[0], elements, types, atomic_numbers, residues_atomwise, molecule_begin_atom_index | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def write_h5_info(outName, struct, atoms_type, atoms_number, atoms_residue, atoms_element, molecules_begin_atom_index, atoms_coordinates_ref): | 
					
					
						
						| 
							 | 
						    if os.path.isfile(outName): | 
					
					
						
						| 
							 | 
						        os.remove(outName) | 
					
					
						
						| 
							 | 
						    with h5py.File(outName, 'w') as oF: | 
					
					
						
						| 
							 | 
						        subgroup = oF.create_group(struct)      | 
					
					
						
						| 
							 | 
						        subgroup.create_dataset('atoms_residue', data= atoms_residue, compression = "gzip", dtype='i8') | 
					
					
						
						| 
							 | 
						        subgroup.create_dataset('molecules_begin_atom_index', data= molecules_begin_atom_index, compression = "gzip", dtype='i8') | 
					
					
						
						| 
							 | 
						        subgroup.create_dataset('atoms_type', data= atoms_type, compression = "gzip", dtype='i8') | 
					
					
						
						| 
							 | 
						        subgroup.create_dataset('atoms_number', data= atoms_number, compression = "gzip", dtype='i8')   | 
					
					
						
						| 
							 | 
						        subgroup.create_dataset('atoms_element', data= atoms_element, compression = "gzip", dtype='i8') | 
					
					
						
						| 
							 | 
						        subgroup.create_dataset('atoms_coordinates_ref', data= atoms_coordinates_ref, compression = "gzip", dtype='f8') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def preprocess(pdbid: str = None, ouputfile: str = "inference_for_md.hdf5", mask: str = "!@H=", mappath: str = "/maps/"): | 
					
					
						
						| 
							 | 
						    traj = convert_to_amber_format(pdbid) | 
					
					
						
						| 
							 | 
						    atoms_coordinates_ref, atoms_element, atoms_type, atoms_number, atoms_residue, molecules_begin_atom_index = get_traj_info(traj[mask], mappath) | 
					
					
						
						| 
							 | 
						    write_h5_info(ouputfile, pdbid, atoms_type, atoms_number, atoms_residue, atoms_element, molecules_begin_atom_index, atoms_coordinates_ref) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_pdb(pdb_code="", filepath=""): | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						        return filepath.name | 
					
					
						
						| 
							 | 
						    except AttributeError as e: | 
					
					
						
						| 
							 | 
						        if pdb_code is None or pdb_code == "": | 
					
					
						
						| 
							 | 
						            return None | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb") | 
					
					
						
						| 
							 | 
						            return f"{pdb_code}.pdb" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_offset(pdb): | 
					
					
						
						| 
							 | 
						    pdb_multiline = pdb.split("\n") | 
					
					
						
						| 
							 | 
						    for line in pdb_multiline: | 
					
					
						
						| 
							 | 
						        if line.startswith("ATOM"): | 
					
					
						
						| 
							 | 
						            return int(line[22:27]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_pdbid_from_filename(filename: str): | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    return filename.split(".")[0] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def predict(pdb_code, pdb_file): | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    pdbid = get_pdbid_from_filename(pdb_file.name) | 
					
					
						
						| 
							 | 
						    mdh5_file = "inference_for_md.hdf5" | 
					
					
						
						| 
							 | 
						    mappath = "/maps" | 
					
					
						
						| 
							 | 
						    mask = "!@H=" | 
					
					
						
						| 
							 | 
						    preprocess(pdbid=pdbid, ouputfile=mdh5_file, mask=mask, mappath=mappath) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    md_H5File = h5py.File(mdh5_file) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    column_names = ["x", "y", "z", "element"] | 
					
					
						
						| 
							 | 
						    atoms_protein = pd.DataFrame(columns = column_names) | 
					
					
						
						| 
							 | 
						    cutoff = md_H5File[pdbid]["molecules_begin_atom_index"][:][-1]  | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    atoms_protein["x"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 0] | 
					
					
						
						| 
							 | 
						    atoms_protein["y"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 1] | 
					
					
						
						| 
							 | 
						    atoms_protein["z"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 2] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    atoms_protein["element"] = md_H5File[pdbid]["atoms_element"][:][:cutoff]   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    item = {} | 
					
					
						
						| 
							 | 
						    item["scores"] = 0 | 
					
					
						
						| 
							 | 
						    item["id"] = pdbid | 
					
					
						
						| 
							 | 
						    item["atoms_protein"] = atoms_protein | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    transform = GNNTransformMD() | 
					
					
						
						| 
							 | 
						    data_item = transform(item) | 
					
					
						
						| 
							 | 
						    adaptability = model(data_item) | 
					
					
						
						| 
							 | 
						    adaptability = adaptability.detach().numpy() | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    data = [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    for i in range(adaptability.shape[0]): | 
					
					
						
						| 
							 | 
						        data.append([i, atom_mapping[atoms_protein.iloc[i, atoms_protein.columns.get_loc("element")] - 1], atoms_protein.iloc[i, atoms_protein.columns.get_loc("x")],atoms_protein.iloc[i, atoms_protein.columns.get_loc("y")],atoms_protein.iloc[i, atoms_protein.columns.get_loc("z")],adaptability[i]]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    topN = 100 | 
					
					
						
						| 
							 | 
						    topN_ind = np.argsort(adaptability)[::-1][:topN]     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    pdb = open(pdb_file.name, "r").read() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    view = py3Dmol.view(width=600, height=400) | 
					
					
						
						| 
							 | 
						    view.setBackgroundColor('white') | 
					
					
						
						| 
							 | 
						    view.addModel(pdb, "pdb") | 
					
					
						
						| 
							 | 
						    view.setStyle({'stick': {'colorscheme': {'prop': 'resi', 'C': 'turquoise'}}}) | 
					
					
						
						| 
							 | 
						    | 
					
					
						
						| 
							 | 
						    for i in range(topN): | 
					
					
						
						| 
							 | 
						        view.addSphere({'center':{'x':atoms_protein.iloc[topN_ind[i], atoms_protein.columns.get_loc("x")], 'y':atoms_protein.iloc[topN_ind[i], atoms_protein.columns.get_loc("y")],'z':atoms_protein.iloc[topN_ind[i], atoms_protein.columns.get_loc("z")]},'radius':adaptability[topN_ind[i]]/1.5,'color':'orange','alpha':0.75})     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    view.zoomTo() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    output = view._make_html().replace("'", '"') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    x = f"""<!DOCTYPE html><html> {output} </html>"""   | 
					
					
						
						| 
							 | 
						    return f"""<iframe  style="width: 100%; height:420px" name="result" allow="midi; geolocation; microphone; camera;  | 
					
					
						
						| 
							 | 
						    display-capture; encrypted-media;" sandbox="allow-modals allow-forms  | 
					
					
						
						| 
							 | 
						    allow-scripts allow-same-origin allow-popups  | 
					
					
						
						| 
							 | 
						    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""  | 
					
					
						
						| 
							 | 
						    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""", pd.DataFrame(data, columns=['index','element','x','y','z','Adaptability']) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						callback = gr.CSVLogger() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def run(): | 
					
					
						
						| 
							 | 
						    with gr.Blocks() as demo: | 
					
					
						
						| 
							 | 
						        gr.Markdown("# Protein Adaptability Prediction") | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        inp = gr.Textbox(placeholder="PDB Code or upload file below", label="Input structure") | 
					
					
						
						| 
							 | 
						        pdb_file = gr.File(label="PDB File Upload") | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        single_btn = gr.Button(label="Run") | 
					
					
						
						| 
							 | 
						        with gr.Row(): | 
					
					
						
						| 
							 | 
						            html = gr.HTML() | 
					
					
						
						| 
							 | 
						        with gr.Row(): | 
					
					
						
						| 
							 | 
						            dataframe = gr.Dataframe() | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						        single_btn.click(fn=predict, inputs=[inp, pdb_file], outputs=[html, dataframe]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    demo.launch(server_name="0.0.0.0", server_port=7860) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    run() | 
					
					
						
						| 
							 | 
						
 |