import spaces import logging import gradio as gr import os import uuid from datetime import datetime import numpy as np from configs.configs_base import configs as configs_base from configs.configs_data import data_configs from configs.configs_inference import inference_configs from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner from protenix.config import parse_configs, parse_sys_args from runner.msa_search import update_infer_json from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader from process_data import process_data import json from typing import Dict, List from Bio.PDB import MMCIFParser, PDBIO import tempfile import shutil from Bio import PDB from gradio_molecule3d import Molecule3D EXAMPLE_PATH = './examples/example.json' example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}] # Custom CSS for styling custom_css = """ #logo { width: 50%; } .title { font-size: 32px; font-weight: bold; color: #4CAF50; display: flex; align-items: center; /* Vertically center the logo and text */ } """ os.environ["LAYERNORM_TYPE"] = "fast_layernorm" os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False" # Set environment variable in the script #os.environ['CUTLASS_PATH'] = './cutlass' # reps = [ # { # "model": 0, # "chain": "", # "resname": "", # "style": "cartoon", # Use cartoon style # "color": "whiteCarbon", # "residue_range": "", # "around": 0, # "byres": False, # "visible": True # Ensure this representation is visible # } # ] reps = [ { "model": 0, "chain": "", "resname": "", "style": "cartoon", "color": "whiteCarbon", "residue_range": "", "around": 0, "byres": False, "opacity": 0.2, }, { "model": 1, "chain": "", "resname": "", "style": "cartoon", "color": "cyanCarbon", "residue_range": "", "around": 0, "byres": False, "opacity": 0.8, } ] ## def align_pdb_files(pdb_file_1, pdb_file_2): # Load the structures parser = PDB.PPBuilder() io = PDB.PDBIO() structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1) structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2) # Superimpose the second structure onto the first super_imposer = PDB.Superimposer() model_1 = structure_1[0] model_2 = structure_2[0] # Extract the coordinates from the two structures atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"] # Use CA atoms atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"] # Align the structures based on the CA atoms coord_1 = [atom.get_coord() for atom in atoms_1] coord_2 = [atom.get_coord() for atom in atoms_2] super_imposer.set_atoms(atoms_1, atoms_2) super_imposer.apply(model_2) # Apply the transformation to model_2 # Save the aligned structure back to the original file io.set_structure(structure_2) # Save the aligned structure to the second file (original file) io.save(pdb_file_2) # Function to convert .cif to .pdb and save as a temporary file def convert_cif_to_pdb(cif_path): """ Convert a CIF file to a PDB file and save it as a temporary file. Args: cif_path (str): Path to the input CIF file. Returns: str: Path to the temporary PDB file. """ # Initialize the MMCIF parser parser = MMCIFParser() structure = parser.get_structure("protein", cif_path) # Create a temporary file for the PDB output with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file: temp_pdb_path = temp_file.name # Save the structure as a PDB file io = PDBIO() io.set_structure(structure) io.save(temp_pdb_path) return temp_pdb_path def plot_3d(pred_loader): # Get the CIF file path for the given prediction ID cif_path = sorted(pred_loader.cif_paths)[0] # Convert the CIF file to a temporary PDB file temp_pdb_path = convert_cif_to_pdb(cif_path) return temp_pdb_path, cif_path def parse_json_input(json_data: List[Dict]) -> Dict: """Convert Protenix JSON format to UI-friendly structure""" components = { "protein_chains": [], "dna_sequences": [], "ligands": [], "complex_name": "" } for entry in json_data: components["complex_name"] = entry.get("name", "") for seq in entry["sequences"]: if "proteinChain" in seq: components["protein_chains"].append({ "sequence": seq["proteinChain"]["sequence"], "count": seq["proteinChain"]["count"] }) elif "dnaSequence" in seq: components["dna_sequences"].append({ "sequence": seq["dnaSequence"]["sequence"], "count": seq["dnaSequence"]["count"] }) elif "ligand" in seq: components["ligands"].append({ "type": seq["ligand"]["ligand"], "count": seq["ligand"]["count"] }) return components def create_protenix_json(input_data: Dict) -> List[Dict]: sequences = [] # Process protein chains for pc in input_data.get("protein_chains", []): # Check that the row has both columns and the sequence is nonempty. if len(pc) >= 2 and pc[0].strip(): sequences.append({ "proteinChain": { "sequence": pc[0].strip(), "count": int(pc[1]) if pc[1] else 1 } }) # Process DNA sequences for dna in input_data.get("dna_sequences", []): if len(dna) >= 2 and dna[0].strip(): sequences.append({ "dnaSequence": { "sequence": dna[0].strip(), "count": int(dna[1]) if dna[1] else 1 } }) # Process RNA sequences for rna in input_data.get("rna_sequences", []): if len(rna) >= 2 and rna[0].strip(): sequences.append({ "rnaSequence": { "sequence": rna[0].strip(), "count": int(rna[1]) if rna[1] else 1 } }) # Process ligands for lig in input_data.get("ligands", []): if len(lig) >= 2 and lig[0].strip(): sequences.append({ "ligand": { "ligand": lig[0].strip(), "count": int(lig[1]) if lig[1] else 1 } }) return [{ "sequences": sequences, "name": input_data.get("complex_name")+f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:3]}" }] def update_json(complex_name, protein_chains, dna_sequences, rna_sequences, ligands): sequences_list = [] # Process protein chains (DataFrame with headers: ["Sequence", "Count"]) if protein_chains: for row in protein_chains: # Check if the row is valid and non-empty if row and len(row) >= 2 and row[0]: sequences_list.append({ "proteinChain": { "sequence": row[0], "count": row[1] } }) # Process DNA sequences if dna_sequences: for row in dna_sequences: if row and len(row) >= 2 and row[0]: sequences_list.append({ "dnaSequence": { "sequence": row[0], "count": row[1] } }) # Process RNA sequences if rna_sequences: for row in rna_sequences: if row and len(row) >= 2 and row[0]: sequences_list.append({ "rnaSequence": { "sequence": row[0], "count": row[1] } }) # Process ligands (DataFrame with headers: ["Ligand Type", "Count"]) if ligands: for row in ligands: if row and len(row) >= 2 and row[0]: sequences_list.append({ "ligand": { "ligand": row[0], "count": row[1] } }) return { "sequences": sequences_list, "name": complex_name } #@torch.inference_mode() @spaces.GPU(duration=180) # Specify a duration to avoid timeout def predict_structure(input_collector: dict): #first initialize runner runner = InferenceRunner(configs) """Handle both input types""" os.makedirs("./output", exist_ok=True) # Generate random filename with timestamp random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" save_path = os.path.join("./output", f"{random_name}.json") print(input_collector) # Handle JSON input if "json" in input_collector: # Handle different input types if isinstance(input_collector["json"], str): # Example JSON case (file path) input_data = json.load(open(input_collector["json"])) elif hasattr(input_collector["json"], "name"): # File upload case input_data = json.load(open(input_collector["json"].name)) else: # Direct JSON data case input_data = input_collector["json"] else: # Manual input case input_data = create_protenix_json(input_collector["data"]) with open(save_path, "w") as f: json.dump(input_data, f, indent=2) if input_data==example_json and input_collector['watermark']==True: configs.saved_path = './output/example_output/' else: # run msa json_file = update_infer_json(save_path, './output', True) # Run prediction configs.input_json_path = json_file configs.watermark = input_collector['watermark'] configs.saved_path = os.path.join("./output/", random_name) infer_predict(runner, configs) #saved_path = os.path.join('./output', f"{sample_name}", f"seed_{seed}", 'predictions') # Generate visualizations pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions')) view3d, cif_path = plot_3d(pred_loader=pred_loader) if configs.watermark: pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig')) view3d_orig, _ = plot_3d(pred_loader=pred_loader) align_pdb_files(view3d, view3d_orig) view3d = [view3d, view3d_orig] plot_best_confidence_measure(os.path.join(configs.saved_path, 'predictions')) confidence_img_path = os.path.join(os.path.join(configs.saved_path, 'predictions'), "best_sample_confidence.png") return view3d, confidence_img_path, cif_path logger = logging.getLogger(__name__) LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s" logging.basicConfig( format=LOG_FORMAT, level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S", filemode="w", ) configs_base["use_deepspeed_evo_attention"] = ( os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "False" ) arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 " configs = {**configs_base, **{"data": data_configs}, **inference_configs} configs = parse_configs( configs=configs, arg_str=arg_str, fill_required_with_null=True, ) configs.load_checkpoint_path='./checkpoint.pt' download_infercence_cache() configs.use_deepspeed_evo_attention=False add_watermark = gr.Checkbox(label="Add Watermark", value=True) add_watermark1 = gr.Checkbox(label="Add Watermark", value=True) with gr.Blocks(title="FoldMark", css=custom_css) as demo: with gr.Row(): # Use a Column to align the logo and title horizontally gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False) with gr.Tab("Structure Predictor (JSON Upload)"): # First create the upload component json_upload = gr.File(label="Upload JSON", file_types=[".json"]) # Then create the example component that references it gr.Examples( examples=[[EXAMPLE_PATH]], inputs=[json_upload], label="Click to use example JSON:", examples_per_page=1 ) # Rest of the components upload_name = gr.Textbox(label="Complex Name (optional)") upload_output = gr.JSON(label="Parsed Components") json_upload.upload( fn=lambda f: parse_json_input(json.load(open(f.name))), inputs=json_upload, outputs=upload_output ) # Shared prediction components with gr.Row(): add_watermark.render() submit_btn = gr.Button("Predict Structure", variant="primary") #structure_view = gr.HTML(label="3D Visualization") with gr.Row(): view3d = Molecule3D(label="3D Visualization(Gray: Unwatermarked; Cyan: Watermarked)", reps=reps) # legend = gr.Markdown(""" # **Color Legend:** # - Gray: Unwatermarked Structure # - Cyan: Watermarked Structure # """) legend = gr.HTML("""
[{ "sequences": [ { "proteinChain": { "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH", "count": 2 } }, { "dnaSequence": { "sequence": "CTAGGTAACATTACTCGCG", "count": 2 } }, { "dnaSequence": { "sequence": "GCGAGTAATGTTAC", "count": 2 } }, { "ligand": { "ligand": "CCD_PCG", "count": 2 } } ], "name": "7pzb" }]""") with gr.Tab("Structure Predictor (Manual Input)"): with gr.Row(): complex_name = gr.Textbox(label="Complex Name") # Replace gr.Group with gr.Accordion with gr.Accordion(label="Protein Chains", open=True): protein_chains = gr.Dataframe( headers=["Sequence", "Count"], datatype=["str", "number"], row_count=1, col_count=(2, "fixed"), type="array" ) # Repeat for other groups with gr.Accordion(label="DNA Sequences (A T G C)", open=True): dna_sequences = gr.Dataframe( headers=["Sequence", "Count"], datatype=["str", "number"], row_count=1, type="array" ) with gr.Accordion(label="RNA Sequences (A U G C)", open=True): rna_sequences = gr.Dataframe( headers=["Sequence", "Count"], datatype=["str", "number"], row_count=1, type="array" ) with gr.Accordion(label="Ligands", open=True): ligands = gr.Dataframe( headers=["Ligand Type", "Count"], datatype=["str", "number"], row_count=1, type="array" ) manual_output = gr.JSON(label="Generated JSON") # Attach a change event to all widgets so that any change updates the JSON output. for widget in [complex_name, protein_chains, dna_sequences, rna_sequences, ligands]: widget.change( fn=update_json, inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands], outputs=manual_output ) # Shared prediction components with gr.Row(): add_watermark1.render() submit_btn = gr.Button("Predict Structure", variant="primary") #structure_view = gr.HTML(label="3D Visualization") with gr.Row(): view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps) with gr.Row(): cif_file = gr.File(label="Download CIF File") with gr.Row(): confidence_plot_image = gr.Image(label="Confidence Measures") input_collector = gr.JSON(visible=False) # Map inputs to a dictionary submit_btn.click( fn=lambda c, p, d, r, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "rna_sequences": r, "ligands": l}, "watermark": w}, inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands, add_watermark1], outputs=input_collector ).then( fn=predict_structure, inputs=input_collector, outputs=[view3d, confidence_plot_image, cif_file] ) @spaces.GPU(duration=120) def is_watermarked(file): #first initialize runner runner = InferenceRunner(configs) # Generate a unique subdirectory and filename unique_id = str(uuid.uuid4().hex[:8]) subdir = os.path.join('./output', unique_id) os.makedirs(subdir, exist_ok=True) filename = f"{unique_id}.cif" file_path = os.path.join(subdir, filename) # Save the uploaded file to the new location shutil.copy(file.name, file_path) #just for fast demonstration, otherwise it takes around 100 seconds if '7r6r_watermarked' in file.name: result=True elif '7pzb_unwatermarked' in file.name: result=False else: # Call your processing functions configs.process_success = process_data(subdir) configs.subdir = subdir result = infer_detect(runner, configs) # This function should return 'Watermarked' or 'Not Watermarked' temp_pdb_path = convert_cif_to_pdb(file_path) if result==False: return "Not Watermarked", temp_pdb_path else: return "Watermarked", temp_pdb_path with gr.Tab("Watermark Detector"): # First create the upload component cif_upload = gr.File(label="Upload .cif", file_types=["..cif"]) with gr.Row(): cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps) # Prediction output prediction_output = gr.Textbox(label="Prediction") # Define the interaction cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view]) # Example files example_files = [ "./examples/7r6r_watermarked.cif", "./examples/7pzb_unwatermarked.cif" ] gr.Examples(examples=example_files, inputs=cif_upload) if __name__ == "__main__": demo.launch(share=True)