|
import spaces |
|
import logging |
|
import gradio as gr |
|
import os |
|
import uuid |
|
from datetime import datetime |
|
import numpy as np |
|
from configs.configs_base import configs as configs_base |
|
from configs.configs_data import data_configs |
|
from configs.configs_inference import inference_configs |
|
from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner |
|
from protenix.config import parse_configs, parse_sys_args |
|
from runner.msa_search import update_infer_json |
|
from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader |
|
from process_data import process_data |
|
import json |
|
from typing import Dict, List |
|
from Bio.PDB import MMCIFParser, PDBIO |
|
import tempfile |
|
import shutil |
|
from Bio import PDB |
|
from gradio_molecule3d import Molecule3D |
|
|
|
# Location of the bundled example input offered in the JSON-upload tab.
EXAMPLE_PATH = "./examples/example.json"

# In-memory copy of the example request. predict_structure compares incoming
# input against this value to decide whether a stored result can be reused.
example_json = [
    {
        "sequences": [
            {
                "proteinChain": {
                    "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
                    "count": 2,
                }
            },
            {
                "dnaSequence": {
                    "sequence": "CTAGGTAACATTACTCGCG",
                    "count": 2,
                }
            },
            {
                "dnaSequence": {
                    "sequence": "GCGAGTAATGTTAC",
                    "count": 2,
                }
            },
            {
                "ligand": {
                    "ligand": "CCD_PCG",
                    "count": 2,
                }
            },
        ],
        "name": "7pzb_need_search_msa",
    }
]
|
|
|
|
|
# Stylesheet handed to gr.Blocks: sizes the header logo and styles the title.
custom_css = (
    "\n"
    "#logo {\n"
    "width: 50%;\n"
    "}\n"
    ".title {\n"
    "font-size: 32px;\n"
    "font-weight: bold;\n"
    "color: #4CAF50;\n"
    "display: flex;\n"
    "align-items: center; /* Vertically center the logo and text */\n"
    "}\n"
)
|
|
|
|
|
# Process-wide switches, exported before any configuration is built.
# The spelling "ATTTENTION" is intentional here: the same key is read back
# further down when the base configuration is prepared, so both must match.
os.environ.update(
    {
        "LAYERNORM_TYPE": "fast_layernorm",
        "USE_DEEPSPEED_EVO_ATTTENTION": "False",
    }
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Molecule3D representations shared by every viewer in the UI: one cartoon
# per model index, differing only in colour scheme and opacity.
reps = [
    {
        "model": model_index,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": color_scheme,
        "residue_range": "",
        "around": 0,
        "byres": False,
        "opacity": opacity,
    }
    for model_index, color_scheme, opacity in (
        (0, "whiteCarbon", 0.2),
        (1, "cyanCarbon", 0.8),
    )
]
|
|
|
|
|
|
|
def align_pdb_files(pdb_file_1, pdb_file_2):
    """Superimpose the structure in ``pdb_file_2`` onto ``pdb_file_1``.

    The fit is computed on the atoms named ``CA`` in the first model of each
    file. ``pdb_file_2`` is overwritten in place with the transformed
    coordinates; ``pdb_file_1`` is only read.

    When the two structures do not expose the same, non-zero number of ``CA``
    atoms no fit can be computed. In that case both files are left untouched
    and a warning is logged instead of raising, because the alignment is only
    a visual aid for the overlay.

    Args:
        pdb_file_1 (str): Path of the reference PDB file.
        pdb_file_2 (str): Path of the PDB file that is moved onto the reference.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    reference = pdb_parser.get_structure("Structure_1", pdb_file_1)
    mobile = pdb_parser.get_structure("Structure_2", pdb_file_2)

    # Only the first model of each structure takes part in the fit.
    reference_model = reference[0]
    mobile_model = mobile[0]

    reference_atoms = [
        atom for atom in reference_model.get_atoms() if atom.get_name() == "CA"
    ]
    mobile_atoms = [
        atom for atom in mobile_model.get_atoms() if atom.get_name() == "CA"
    ]

    # Superimposer.set_atoms needs two equally long, non-empty atom lists.
    if not reference_atoms or len(reference_atoms) != len(mobile_atoms):
        logging.getLogger(__name__).warning(
            "Skipping alignment of %s onto %s: %d vs %d CA atoms",
            pdb_file_2,
            pdb_file_1,
            len(mobile_atoms),
            len(reference_atoms),
        )
        return

    super_imposer = PDB.Superimposer()
    super_imposer.set_atoms(reference_atoms, mobile_atoms)
    # Apply the fitted rotation/translation to the whole second model.
    super_imposer.apply(mobile_model)

    writer = PDB.PDBIO()
    writer.set_structure(mobile)
    writer.save(pdb_file_2)
|
|
|
|
|
def convert_cif_to_pdb(cif_path):
    """
    Convert a CIF file to a PDB file and save it as a temporary file.

    The temporary file is created with ``delete=False`` and is therefore not
    removed automatically; the caller owns the returned path. If writing the
    PDB fails, the temporary file is removed before the error is re-raised.

    Args:
        cif_path (str): Path to the input CIF file.

    Returns:
        str: Path to the temporary PDB file.
    """
    # QUIET=True matches the PDB parser used for alignment in this file and
    # keeps parser warnings out of the application log.
    structure = MMCIFParser(QUIET=True).get_structure("protein", cif_path)

    # Reserve a unique path, then close the handle so PDBIO can reopen it.
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
        temp_pdb_path = temp_file.name

    writer = PDBIO()
    writer.set_structure(structure)
    try:
        writer.save(temp_pdb_path)
    except Exception:
        # Do not leave an empty or partial file behind when the write fails.
        if os.path.exists(temp_pdb_path):
            os.remove(temp_pdb_path)
        raise

    return temp_pdb_path
|
|
|
def plot_3d(pred_loader):
    """Prepare the first predicted structure for the 3D viewer.

    Args:
        pred_loader: Object exposing ``cif_paths``, a collection of CIF paths.

    Returns:
        tuple: ``(pdb_path, cif_path)`` where ``cif_path`` is the first entry
        of ``cif_paths`` in sorted order and ``pdb_path`` is its temporary
        PDB conversion.
    """
    ordered_cifs = sorted(pred_loader.cif_paths)
    first_cif = ordered_cifs[0]
    return convert_cif_to_pdb(first_cif), first_cif
|
|
|
def parse_json_input(json_data: List[Dict]) -> Dict:
    """Convert Protenix JSON format to a UI-friendly structure.

    Args:
        json_data: Parsed Protenix input, a list of entries each holding a
            ``"name"`` and a ``"sequences"`` list. A single entry given as a
            plain dict is accepted as well.

    Returns:
        Dict with the keys ``"protein_chains"``, ``"dna_sequences"``,
        ``"rna_sequences"``, ``"ligands"`` (lists of small dicts carrying the
        sequence or ligand type plus its ``"count"``) and ``"complex_name"``.
        With several entries the components are pooled and the name of the
        last entry wins. A missing ``"count"`` is reported as 1.
    """
    components = {
        "protein_chains": [],
        "dna_sequences": [],
        "rna_sequences": [],
        "ligands": [],
        "complex_name": "",
    }

    # input wrapper key -> (output bucket, payload field, output field name);
    # the first wrapper key found in a sequence item decides its bucket.
    layout = {
        "proteinChain": ("protein_chains", "sequence", "sequence"),
        "dnaSequence": ("dna_sequences", "sequence", "sequence"),
        "rnaSequence": ("rna_sequences", "sequence", "sequence"),
        "ligand": ("ligands", "ligand", "type"),
    }

    if isinstance(json_data, dict):
        json_data = [json_data]

    for entry in json_data:
        components["complex_name"] = entry.get("name", "")
        for seq in entry.get("sequences", []):
            for wrapper_key, (bucket, source_field, target_field) in layout.items():
                if wrapper_key in seq:
                    item = seq[wrapper_key]
                    components[bucket].append({
                        target_field: item[source_field],
                        "count": item.get("count", 1),
                    })
                    break
    return components
|
|
|
|
|
def create_protenix_json(input_data: Dict) -> List[Dict]:
    """Build a Protenix input document from the manual-input tables.

    Args:
        input_data: Dict that may hold ``"complex_name"`` and the tables
            ``"protein_chains"``, ``"dna_sequences"``, ``"rna_sequences"``
            and ``"ligands"``. Each table is a list of ``[value, count]``
            rows; a missing or ``None`` table counts as empty.

    Returns:
        A one-element list with ``"sequences"`` (protein, DNA, RNA, ligand
        order) and ``"name"``: the complex name immediately followed by a
        timestamp and a short random suffix. A missing name is treated as an
        empty string.

    Raises:
        ValueError: If a non-empty count cell cannot be converted to ``int``.
    """
    table_layout = (
        ("protein_chains", "proteinChain", "sequence"),
        ("dna_sequences", "dnaSequence", "sequence"),
        ("rna_sequences", "rnaSequence", "sequence"),
        ("ligands", "ligand", "ligand"),
    )

    sequences = []
    for table_key, wrapper_key, value_key in table_layout:
        sequences.extend(
            _rows_to_sequences(input_data.get(table_key), wrapper_key, value_key)
        )

    unique_suffix = (
        f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:3]}"
    )
    return [{
        "sequences": sequences,
        # ``or ""`` keeps the concatenation valid when no name was supplied.
        "name": (input_data.get("complex_name") or "") + unique_suffix,
    }]


def _rows_to_sequences(rows, wrapper_key, value_key):
    """Turn ``[value, count]`` table rows into Protenix sequence dicts."""
    converted = []
    for row in rows or []:
        if not row or len(row) < 2:
            continue
        # Empty table cells may arrive as None instead of "".
        value = row[0].strip() if isinstance(row[0], str) else ""
        if not value:
            continue
        converted.append({
            wrapper_key: {
                value_key: value,
                # A blank count cell defaults to a single copy.
                "count": int(row[1]) if row[1] else 1,
            }
        })
    return converted
|
|
|
|
|
def update_json(complex_name, protein_chains, dna_sequences, rna_sequences, ligands):
    """Assemble the live JSON preview shown next to the manual-input tables.

    Each table is a list of ``[value, count]`` rows (or empty/``None``).
    Rows that are empty, shorter than two cells, or whose first cell is
    falsy are skipped. Values and counts are copied through unchanged.

    Returns:
        dict: ``{"sequences": [...], "name": complex_name}`` with the
        sequences listed in protein, DNA, RNA, ligand order.
    """
    table_layout = (
        (protein_chains, "proteinChain", "sequence"),
        (dna_sequences, "dnaSequence", "sequence"),
        (rna_sequences, "rnaSequence", "sequence"),
        (ligands, "ligand", "ligand"),
    )

    collected = []
    for rows, wrapper_key, value_key in table_layout:
        if not rows:
            continue
        collected.extend(
            {wrapper_key: {value_key: row[0], "count": row[1]}}
            for row in rows
            if row and len(row) >= 2 and row[0]
        )

    return {"sequences": collected, "name": complex_name}
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=180)
def predict_structure(input_collector: dict):
    """Run (or replay) a structure prediction for the collected UI input.

    Handles both input types: an uploaded JSON file and the manual tables.

    Args:
        input_collector: Payload assembled by the submit handlers. It holds a
            ``"watermark"`` flag plus either ``"json"`` (a file path, an
            object exposing ``.name``, or already-parsed JSON) or ``"data"``
            (the manual-input tables passed to ``create_protenix_json``).

    Returns:
        tuple: ``(view3d, confidence_img_path, cif_path)``. ``view3d`` is one
        PDB path, or a two-element list ``[predictions, predictions_orig]``
        of PDB paths when watermarking is enabled.

    Raises:
        gr.Error: If the JSON tab was submitted without any input.

    Note:
        The module-level ``configs`` object is mutated on every call.
    """
    os.makedirs("./output", exist_ok=True)

    # Unique per-request name for the saved input JSON and the result folder.
    random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    save_path = os.path.join("./output", f"{random_name}.json")

    logger.info("Prediction request: %s", input_collector)

    if "json" in input_collector:
        uploaded = input_collector["json"]
        # The upload may arrive as a path string or as an object with .name;
        # anything else is taken to be already-parsed JSON.
        if isinstance(uploaded, str):
            json_path = uploaded
        elif hasattr(uploaded, "name"):
            json_path = uploaded.name
        else:
            json_path = None

        if json_path is None:
            input_data = uploaded
        else:
            with open(json_path, encoding="utf-8") as f:
                input_data = json.load(f)

        if input_data is None:
            raise gr.Error("Please upload a JSON file before predicting.")
    else:
        input_data = create_protenix_json(input_collector["data"])

    with open(save_path, "w") as f:
        json.dump(input_data, f, indent=2)

    # Record the flag for this request in every branch; the overlay decision
    # below must not depend on a value left over from an earlier call.
    watermark = bool(input_collector["watermark"])
    configs.watermark = watermark

    if watermark and input_data == example_json:
        # The bundled example with watermarking is served from stored output,
        # so no model is loaded for it.
        configs.saved_path = './output/example_output/'
    else:
        runner = InferenceRunner(configs)
        json_file = update_infer_json(save_path, './output', True)
        configs.input_json_path = json_file
        configs.saved_path = os.path.join("./output/", random_name)
        infer_predict(runner, configs)

    predictions_dir = os.path.join(configs.saved_path, 'predictions')
    view3d, cif_path = plot_3d(pred_loader=PredictionLoader(predictions_dir))
    if configs.watermark:
        orig_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig'))
        view3d_orig, _ = plot_3d(pred_loader=orig_loader)
        align_pdb_files(view3d, view3d_orig)
        # NOTE(review): `reps` colours model 0 gray and model 1 cyan, while
        # the UI legend labels gray as "unwatermarked" - confirm that this
        # list order matches the legend.
        view3d = [view3d, view3d_orig]

    plot_best_confidence_measure(predictions_dir)
    confidence_img_path = os.path.join(predictions_dir, "best_sample_confidence.png")

    return view3d, confidence_img_path, cif_path
|
|
|
|
|
# Module logger plus root logging setup for the whole application.
logger = logging.getLogger(__name__)
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
logging.basicConfig(
    format=LOG_FORMAT,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Mirror the environment switch into the base config. The flag is enabled
# only when the variable is set to "true" (case-insensitive); comparing
# against "False" would turn the kernel on exactly when it was switched off.
configs_base["use_deepspeed_evo_attention"] = (
    os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", "False").strip().lower() == "true"
)

# Fixed inference settings, passed in the form of command-line arguments.
arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "

# Later sources override earlier ones: base < data < inference.
configs = {**configs_base, "data": data_configs, **inference_configs}
configs = parse_configs(
    configs=configs,
    arg_str=arg_str,
    fill_required_with_null=True,
)
configs.load_checkpoint_path = './checkpoint.pt'
download_infercence_cache()
# Force the DeepSpeed attention path off regardless of what parsing produced.
configs.use_deepspeed_evo_attention = False
|
|
|
# The watermark toggles are created up-front and attached to their tabs
# later through ``.render()``.
add_watermark = gr.Checkbox(label="Add Watermark", value=True)
add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)

# Static legend shown under the viewer of the JSON-upload tab.
_COLOR_LEGEND_HTML = """
<div>
<strong>Color Legend:</strong><br>
- <span style="color:grey;">Gray: Unwatermarked Structure</span><br>
- <span style="color:cyan;">Cyan: Watermarked Structure</span>
</div>
"""

# Help text of the JSON-upload tab: a sample input document.
_JSON_EXAMPLE_MARKDOWN = """
The example of the uploaded json file for structure prediction.
<pre>
[{
"sequences": [
{
"proteinChain": {
"sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "CTAGGTAACATTACTCGCG",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "GCGAGTAATGTTAC",
"count": 2
}
},
{
"ligand": {
"ligand": "CCD_PCG",
"count": 2
}
}
],
"name": "7pzb"
}]
</pre>
"""


def _uploaded_path(file):
    """Return the filesystem path behind an uploaded-file value.

    Accepts either a path string or an object exposing ``.name``.
    """
    return file if isinstance(file, str) else file.name


def _preview_uploaded_json(file):
    """Parse an uploaded JSON file into the component overview of the tab."""
    with open(_uploaded_path(file), encoding="utf-8") as handle:
        return parse_json_input(json.load(handle))


@spaces.GPU(duration=120)
def is_watermarked(file):
    """Decide whether an uploaded CIF structure carries a watermark.

    The upload is copied into a fresh ``./output/<id>/`` directory. The two
    bundled example files are recognised by name and answered directly;
    every other file is preprocessed and passed to the detector.

    Args:
        file: Uploaded file (path string or object exposing ``.name``), or
            ``None`` when there is no upload.

    Returns:
        tuple: ``(label, pdb_path)`` where ``label`` is ``"Watermarked"`` or
        ``"Not Watermarked"`` and ``pdb_path`` is a temporary PDB conversion
        of the upload for the viewer. ``("", None)`` when ``file`` is ``None``.
    """
    if file is None:
        # The change event may also fire when the upload is cleared.
        return "", None

    source_path = _uploaded_path(file)

    unique_id = uuid.uuid4().hex[:8]
    subdir = os.path.join('./output', unique_id)
    os.makedirs(subdir, exist_ok=True)
    file_path = os.path.join(subdir, f"{unique_id}.cif")
    shutil.copy(source_path, file_path)

    if '7r6r_watermarked' in source_path:
        result = True
    elif '7pzb_unwatermarked' in source_path:
        result = False
    else:
        # Only build the model runner when detection actually has to run.
        runner = InferenceRunner(configs)
        configs.process_success = process_data(subdir)
        configs.subdir = subdir
        result = infer_detect(runner, configs)

    temp_pdb_path = convert_cif_to_pdb(file_path)
    label = "Watermarked" if result else "Not Watermarked"
    return label, temp_pdb_path


with gr.Blocks(title="FoldMark", css=custom_css) as demo:
    with gr.Row():
        gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)

    with gr.Tab("Structure Predictor (JSON Upload)"):
        json_upload = gr.File(label="Upload JSON", file_types=[".json"])

        gr.Examples(
            examples=[[EXAMPLE_PATH]],
            inputs=[json_upload],
            label="Click to use example JSON:",
            examples_per_page=1,
        )

        # NOTE(review): this textbox is not wired to any handler here.
        upload_name = gr.Textbox(label="Complex Name (optional)")
        upload_output = gr.JSON(label="Parsed Components")

        json_upload.upload(
            fn=_preview_uploaded_json,
            inputs=json_upload,
            outputs=upload_output,
        )

        with gr.Row():
            add_watermark.render()
            submit_btn = gr.Button("Predict Structure", variant="primary")

        with gr.Row():
            view3d = Molecule3D(label="3D Visualization(Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)

        legend = gr.HTML(_COLOR_LEGEND_HTML)

        with gr.Row():
            cif_file = gr.File(label="Download CIF File")
        with gr.Row():
            confidence_plot_image = gr.Image(label="Confidence Measures")

        # Hidden carrier that hands the collected inputs to predict_structure.
        input_collector = gr.JSON(visible=False)

        submit_btn.click(
            fn=lambda j, w: {"json": j, "watermark": w},
            inputs=[json_upload, add_watermark],
            outputs=input_collector,
        ).then(
            fn=predict_structure,
            inputs=input_collector,
            outputs=[view3d, confidence_plot_image, cif_file],
        )

        gr.Markdown(_JSON_EXAMPLE_MARKDOWN)

    with gr.Tab("Structure Predictor (Manual Input)"):
        with gr.Row():
            complex_name = gr.Textbox(label="Complex Name")

        with gr.Accordion(label="Protein Chains", open=True):
            protein_chains = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1,
                col_count=(2, "fixed"),
                type="array",
            )

        with gr.Accordion(label="DNA Sequences (A T G C)", open=True):
            dna_sequences = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1,
                type="array",
            )

        with gr.Accordion(label="RNA Sequences (A U G C)", open=True):
            rna_sequences = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1,
                type="array",
            )

        with gr.Accordion(label="Ligands", open=True):
            ligands = gr.Dataframe(
                headers=["Ligand Type", "Count"],
                datatype=["str", "number"],
                row_count=1,
                type="array",
            )

        manual_output = gr.JSON(label="Generated JSON")

        # Refresh the preview whenever any of the manual inputs changes.
        for widget in [complex_name, protein_chains, dna_sequences, rna_sequences, ligands]:
            widget.change(
                fn=update_json,
                inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands],
                outputs=manual_output,
            )

        # The names below are rebound for this tab. The first tab's handlers
        # were already wired to its own components above.
        with gr.Row():
            add_watermark1.render()
            submit_btn = gr.Button("Predict Structure", variant="primary")

        with gr.Row():
            view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)

        with gr.Row():
            cif_file = gr.File(label="Download CIF File")
        with gr.Row():
            confidence_plot_image = gr.Image(label="Confidence Measures")

        input_collector = gr.JSON(visible=False)

        submit_btn.click(
            fn=lambda c, p, d, r, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "rna_sequences": r, "ligands": l}, "watermark": w},
            inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands, add_watermark1],
            outputs=input_collector,
        ).then(
            fn=predict_structure,
            inputs=input_collector,
            outputs=[view3d, confidence_plot_image, cif_file],
        )

    with gr.Tab("Watermark Detector"):
        # Extension filter; ".cif" with a single leading dot.
        cif_upload = gr.File(label="Upload .cif", file_types=[".cif"])

        with gr.Row():
            cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)

        prediction_output = gr.Textbox(label="Prediction")

        cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])

        example_files = [
            "./examples/7r6r_watermarked.cif",
            "./examples/7pzb_unwatermarked.cif",
        ]

        gr.Examples(examples=example_files, inputs=cif_upload)
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    launch_options = {"share": True}
    demo.launch(**launch_options)