# FoldMark / app.py
# Last commit: "Update app.py" by Zaixi (4d252c9, verified)
import spaces
import logging
import gradio as gr
import os
import uuid
from datetime import datetime
import numpy as np
from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
from protenix.config import parse_configs, parse_sys_args
from runner.msa_search import update_infer_json
from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
from process_data import process_data
import json
from typing import Dict, List
from Bio.PDB import MMCIFParser, PDBIO
import tempfile
import shutil
from Bio import PDB
from gradio_molecule3d import Molecule3D
# Bundled example input, offered through gr.Examples in the JSON-upload tab.
EXAMPLE_PATH = './examples/example.json'

# In-memory copy of the bundled example. Submitted input is compared against
# this list to decide whether precomputed example outputs can be served.
example_json = [
    {
        'sequences': [
            {
                'proteinChain': {
                    'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH',
                    'count': 2,
                },
            },
            {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}},
            {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}},
            {'ligand': {'ligand': 'CCD_PCG', 'count': 2}},
        ],
        'name': '7pzb_need_search_msa',
    },
]

# Page CSS: "#logo" sizes the header image (elem_id="logo"); ".title" styles a
# flex title row.
custom_css = """
#logo {
width: 50%;
}
.title {
font-size: 32px;
font-weight: bold;
color: #4CAF50;
display: flex;
align-items: center; /* Vertically center the logo and text */
}
"""
# Process-wide switches, set right after the imports. LAYERNORM_TYPE is
# presumably consumed by the imported protenix modules (not visible here).
# USE_DEEPSPEED_EVO_ATTTENTION (sic, triple "T") is read back when the model
# config is assembled, so the spelling must stay exactly as written.
os.environ.update(
    LAYERNORM_TYPE="fast_layernorm",
    USE_DEEPSPEED_EVO_ATTTENTION="False",
)
def _cartoon_rep(model, color, opacity):
    """Return one Molecule3D cartoon representation covering every chain of *model*."""
    return {
        "model": model,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": color,
        "residue_range": "",
        "around": 0,
        "byres": False,
        "opacity": opacity,
    }


# Viewer styling shared by every Molecule3D widget: model 0 is drawn as a faint
# gray cartoon and model 1 as a mostly opaque cyan cartoon. The viewer labels
# describe these as "Gray: Unwatermarked; Cyan: Watermarked".
reps = [
    _cartoon_rep(0, "whiteCarbon", 0.2),
    _cartoon_rep(1, "cyanCarbon", 0.8),
]
##
def align_pdb_files(pdb_file_1, pdb_file_2):
    """Superimpose the structure in *pdb_file_2* onto *pdb_file_1*, in place.

    Only the first model of each file is used. Atoms are paired by position, so
    both files must list their atoms in the same order (true for two predictions
    of the same input). Atoms named "CA" drive the fit. When neither structure
    has any (e.g. nucleic-acid-only or ligand-only complexes), every atom is used
    instead, so the superposition never runs on an empty selection. If there is
    still nothing to pair, the files are left untouched.

    Args:
        pdb_file_1: Path of the reference (fixed) PDB file; it is only read.
        pdb_file_2: Path of the PDB file to move; it is overwritten with the
            transformed coordinates.

    Raises:
        Bio.PDB.PDBExceptions.PDBException: if the two atom selections differ in
            length (raised by ``Superimposer.set_atoms``).
    """
    structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1)
    structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2)
    model_1 = structure_1[0]
    model_2 = structure_2[0]

    def _alignment_atoms(model):
        """Return the CA atoms of *model*, or all of its atoms if it has none."""
        ca_atoms = [atom for atom in model.get_atoms() if atom.get_name() == "CA"]
        return ca_atoms or list(model.get_atoms())

    atoms_1 = _alignment_atoms(model_1)
    atoms_2 = _alignment_atoms(model_2)
    if not atoms_1 and not atoms_2:
        # Nothing to superimpose; an empty selection would make the SVD fail.
        return

    super_imposer = PDB.Superimposer()
    super_imposer.set_atoms(atoms_1, atoms_2)  # fixed, moving
    super_imposer.apply(model_2)  # rotate/translate every atom of model_2

    # Overwrite the second file with the aligned coordinates.
    io = PDB.PDBIO()
    io.set_structure(structure_2)
    io.save(pdb_file_2)
def convert_cif_to_pdb(cif_path):
    """Write the structure in *cif_path* to a new temporary ``.pdb`` file.

    The temporary file is created with ``delete=False``, so it outlives this call;
    the caller owns it.

    Args:
        cif_path (str): Path to the input mmCIF file.

    Returns:
        str: Path of the freshly written PDB file.
    """
    structure = MMCIFParser().get_structure("protein", cif_path)

    # Reserve a unique *.pdb path; the handle is closed before PDBIO reopens it.
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as reserved:
        pdb_path = reserved.name

    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(pdb_path)
    return pdb_path
def plot_3d(pred_loader):
    """Prepare the first predicted structure of *pred_loader* for the 3D viewer.

    Despite its name this function does not plot. It picks the first CIF path in
    sorted order and converts it to a temporary PDB file for ``Molecule3D``.

    Args:
        pred_loader: Object exposing ``cif_paths``, an iterable of CIF file paths.

    Returns:
        tuple[str, str]: (temporary PDB path, original CIF path).

    Raises:
        FileNotFoundError: if the loader has no CIF files. Before this change it
            was an opaque ``IndexError``.
    """
    cif_paths = sorted(pred_loader.cif_paths)
    if not cif_paths:
        raise FileNotFoundError("No predicted CIF files were found for this prediction.")
    cif_path = cif_paths[0]
    temp_pdb_path = convert_cif_to_pdb(cif_path)
    return temp_pdb_path, cif_path
def parse_json_input(json_data: List[Dict]) -> Dict:
    """Convert Protenix JSON input into the summary shown as "Parsed Components".

    Args:
        json_data: Protenix input, i.e. a list of complex dicts, each with a
            ``"sequences"`` list and an optional ``"name"``. A single complex
            dict (not wrapped in a list) is also accepted.

    Returns:
        Dict with these keys:

        - ``"protein_chains"``, ``"dna_sequences"``, ``"rna_sequences"``: lists of
          ``{"sequence", "count"}``.
        - ``"ligands"``: list of ``{"type", "count"}``.
        - ``"complex_name"``: name of the last complex, or ``""`` if absent.

        A missing ``count`` defaults to 1. Entries of other kinds are ignored.
    """
    if isinstance(json_data, dict):
        json_data = [json_data]

    components = {
        "protein_chains": [],
        "dna_sequences": [],
        "rna_sequences": [],
        "ligands": [],
        "complex_name": "",
    }
    # Protenix entity key -> (summary list, source field, summary field)
    kinds = {
        "proteinChain": ("protein_chains", "sequence", "sequence"),
        "dnaSequence": ("dna_sequences", "sequence", "sequence"),
        "rnaSequence": ("rna_sequences", "sequence", "sequence"),
        "ligand": ("ligands", "ligand", "type"),
    }
    for entry in json_data:
        components["complex_name"] = entry.get("name", "")
        for seq in entry.get("sequences", []):
            for entity_key, (bucket, src_field, dst_field) in kinds.items():
                if entity_key in seq:
                    item = seq[entity_key]
                    components[bucket].append({
                        dst_field: item[src_field],
                        "count": item.get("count", 1),
                    })
                    break  # first matching kind wins, like the original if/elif
    return components
def _cell_text(value) -> str:
    """Return a Dataframe cell as stripped text; ``None`` becomes ``""``."""
    return "" if value is None else str(value).strip()


def _cell_count(value) -> int:
    """Parse a Dataframe "Count" cell. Blank, missing, NaN and zero counts become 1."""
    if isinstance(value, str):
        value = value.strip()
    if value is None or value == "" or value != value:  # value != value only for NaN
        return 1
    count = int(float(value))  # accepts 2, 2.0, "2" and "2.0"
    return count or 1


# (key in the manual-input dict, Protenix entity key, field holding sequence/ligand id)
_MANUAL_INPUT_KINDS = (
    ("protein_chains", "proteinChain", "sequence"),
    ("dna_sequences", "dnaSequence", "sequence"),
    ("rna_sequences", "rnaSequence", "sequence"),
    ("ligands", "ligand", "ligand"),
)


def create_protenix_json(input_data: Dict) -> List[Dict]:
    """Build Protenix input JSON from the manual-input tab's Dataframe rows.

    Args:
        input_data: Dict with optional keys ``complex_name`` and
            ``protein_chains`` / ``dna_sequences`` / ``rna_sequences`` /
            ``ligands``. Each of the latter is a list of ``[value, count]`` rows.
            Rows that are empty, shorter than two cells, or have a blank first
            cell are skipped.

    Returns:
        A one-element list ``[{"sequences": [...], "name": name}]``. ``name`` is
        the complex name followed by a ``YYYYmmdd_HHMMSS_xxx`` timestamp/random
        suffix, so repeated submissions do not collide. The complex name is typed
        by the user, so any character other than letters, digits, "-", "_" or "."
        is replaced by "_"; this keeps the name from introducing path separators.
    """
    sequences = []
    for input_key, entity_key, value_field in _MANUAL_INPUT_KINDS:
        for row in input_data.get(input_key) or []:
            if not row or len(row) < 2:
                continue
            value = _cell_text(row[0])
            if not value:
                continue
            sequences.append({
                entity_key: {
                    value_field: value,
                    "count": _cell_count(row[1]),
                }
            })

    raw_name = _cell_text(input_data.get("complex_name"))
    safe_name = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in raw_name)
    suffix = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:3]}"
    return [{
        "sequences": sequences,
        "name": safe_name + suffix,
    }]
def update_json(complex_name, protein_chains, dna_sequences, rna_sequences, ligands):
    """Return the live "Generated JSON" preview for the manual-input tab.

    Each table argument is a list of ``[value, count]`` rows (or a falsy value).
    A row is kept when it is truthy, has at least two cells, and its first cell
    is truthy. Values and counts are copied through unchanged.

    Returns:
        dict: ``{"sequences": [...], "name": complex_name}``.
    """
    # (rows, Protenix entity key, field that holds the sequence / ligand id)
    tables = (
        (protein_chains, "proteinChain", "sequence"),
        (dna_sequences, "dnaSequence", "sequence"),
        (rna_sequences, "rnaSequence", "sequence"),
        (ligands, "ligand", "ligand"),
    )
    entries = []
    for rows, entity_key, value_field in tables:
        for row in rows or ():
            # Skip blank rows, rows missing the count column, and empty values.
            if not row or len(row) < 2 or not row[0]:
                continue
            entries.append({entity_key: {value_field: row[0], "count": row[1]}})
    return {"sequences": entries, "name": complex_name}
def _load_prediction_input(input_collector: dict):
    """Return the Protenix input (a list of complex dicts) for one submission.

    ``input_collector`` is the dict packed by the submit handlers. It is either
    ``{"json": <path | file object | parsed JSON>, "watermark": ...}`` from the
    JSON-upload tab or ``{"data": {...}, "watermark": ...}`` from the manual tab.

    Raises:
        gr.Error: if no JSON was supplied, or the manual tables hold no entries.
    """
    if "json" not in input_collector:
        # Manual input case: build Protenix JSON from the Dataframe rows.
        input_data = create_protenix_json(input_collector["data"])
        if not any(entry.get("sequences") for entry in input_data):
            raise gr.Error("Please enter at least one sequence or ligand before predicting.")
        return input_data

    source = input_collector["json"]
    if source is None:
        raise gr.Error("Please upload a JSON file (or select the example) before predicting.")
    if isinstance(source, str):  # file path (example or serialized upload)
        path = source
    elif hasattr(source, "name"):  # file-object upload
        path = source.name
    else:  # already-parsed JSON data
        return source
    with open(path) as handle:  # close the handle deterministically
        return json.load(handle)


@spaces.GPU(duration=180)  # Specify a duration to avoid timeout
def predict_structure(input_collector: dict):
    """Run structure prediction for one UI submission and prepare its outputs.

    The input is saved to ``./output/<timestamp>_<id>.json``. Two cases:

    - The input equals the bundled ``example_json`` and watermarking is
      requested: precomputed results under ``./output/example_output/`` are
      served directly.
    - Otherwise: MSA search runs via ``update_infer_json``, then
      ``infer_predict`` writes results to ``./output/<timestamp>_<id>/``.

    Note that this mutates the module-level ``configs``.

    Args:
        input_collector: Dict produced by the submit handlers; see
            ``_load_prediction_input``. ``"watermark"`` selects watermarked output.

    Returns:
        tuple: ``(view3d, confidence_img_path, cif_path)``. ``view3d`` depends on
        watermarking:

        - Without watermarking: one PDB path for ``predictions``.
        - With watermarking: ``[orig_pdb, pdb]``, where ``orig_pdb`` comes from
          ``predictions_orig`` (presumably the unwatermarked run) and is
          superimposed onto ``pdb``. This puts it in model slot 0 (gray) and the
          ``predictions`` structure in slot 1 (cyan), matching the viewer legend
          "Gray: Unwatermarked; Cyan: Watermarked".

    Raises:
        gr.Error: on missing or empty input.
    """
    logging.getLogger(__name__).info("predict_structure input: %s", input_collector)
    os.makedirs("./output", exist_ok=True)
    # Timestamp + random suffix keeps concurrent submissions apart.
    random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    save_path = os.path.join("./output", f"{random_name}.json")

    input_data = _load_prediction_input(input_collector)
    with open(save_path, "w") as f:
        json.dump(input_data, f, indent=2)

    watermark = bool(input_collector.get("watermark"))
    # Set on every request. ``configs`` is shared module state, and the example
    # shortcut below previously left a stale value from an earlier request here.
    configs.watermark = watermark
    if input_data == example_json and watermark:
        configs.saved_path = './output/example_output/'
    else:
        # Only build the (expensive) runner when inference actually runs.
        runner = InferenceRunner(configs)
        json_file = update_infer_json(save_path, './output', True)  # run MSA
        configs.input_json_path = json_file
        configs.saved_path = os.path.join("./output/", random_name)
        infer_predict(runner, configs)

    predictions_dir = os.path.join(configs.saved_path, 'predictions')
    view3d, cif_path = plot_3d(pred_loader=PredictionLoader(predictions_dir))
    if watermark:
        orig_dir = os.path.join(configs.saved_path, 'predictions_orig')
        orig_pdb, _ = plot_3d(pred_loader=PredictionLoader(orig_dir))
        align_pdb_files(view3d, orig_pdb)  # move orig_pdb onto view3d
        # Model 0 is rendered gray and model 1 cyan (see ``reps``).
        view3d = [orig_pdb, view3d]

    plot_best_confidence_measure(predictions_dir)
    confidence_img_path = os.path.join(predictions_dir, "best_sample_confidence.png")
    return view3d, confidence_img_path, cif_path
logger = logging.getLogger(__name__)
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
# No ``filename`` is passed, so basicConfig installs a stream handler. The former
# ``filemode="w"`` argument only affects file handlers and was ignored, so it is
# omitted.
logging.basicConfig(
    format=LOG_FORMAT,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Opt-in flag: enabled only when the variable equals "true" (case-insensitive).
# The previous ``== "False"`` test was inverted and switched the feature ON
# whenever the environment said "False". The triple-T spelling (sic) must match
# the key assigned to ``os.environ`` near the top of this file.
configs_base["use_deepspeed_evo_attention"] = (
    os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", "false").strip().lower() == "true"
)

# CLI-style overrides: seed 101, model.N_cycle=10,
# sample_diffusion.N_sample=5, sample_diffusion.N_step=200.
arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
configs = {**configs_base, "data": data_configs, **inference_configs}
configs = parse_configs(
    configs=configs,
    arg_str=arg_str,
    fill_required_with_null=True,
)
configs.load_checkpoint_path = './checkpoint.pt'
download_infercence_cache()
# Hard override: always run without DeepSpeed Evo attention, whatever the env says.
configs.use_deepspeed_evo_attention = False
# Watermark toggles are created outside the Blocks context and placed with
# ``.render()`` inside each prediction tab.
add_watermark = gr.Checkbox(label="Add Watermark", value=True)
add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)


def _parse_uploaded_json(uploaded):
    """Parse an uploaded Protenix JSON file into the "Parsed Components" summary.

    ``uploaded`` is the ``gr.File`` value: a path string, or a file-like object
    exposing the temporary path as ``.name``.
    """
    path = uploaded if isinstance(uploaded, str) else uploaded.name
    with open(path) as handle:  # close the handle instead of leaking it
        return parse_json_input(json.load(handle))


@spaces.GPU(duration=120)
def is_watermarked(file):
    """Classify an uploaded mmCIF structure as watermarked or not.

    The upload is copied to ``./output/<id>/<id>.cif``. For any file that is not
    one of the bundled examples, ``process_data`` and ``infer_detect`` then run
    on that directory. Note that this mutates the module-level ``configs``.

    Returns:
        tuple[str, str]: ``("Watermarked" | "Not Watermarked", temporary PDB path
        for the 3D viewer)``.
    """
    unique_id = uuid.uuid4().hex[:8]
    subdir = os.path.join('./output', unique_id)
    os.makedirs(subdir, exist_ok=True)
    file_path = os.path.join(subdir, f"{unique_id}.cif")
    # Save the uploaded file to the new location
    shutil.copy(file.name, file_path)

    # Shortcut for the two bundled examples (full detection takes ~100 seconds).
    # NOTE(review): this matches substrings of the uploaded path, so any file
    # whose name contains these markers also takes the shortcut.
    if '7r6r_watermarked' in file.name:
        result = True
    elif '7pzb_unwatermarked' in file.name:
        result = False
    else:
        # Only build the (expensive) runner when detection actually runs.
        runner = InferenceRunner(configs)
        configs.process_success = process_data(subdir)
        configs.subdir = subdir
        result = infer_detect(runner, configs)

    temp_pdb_path = convert_cif_to_pdb(file_path)
    # Explicit ``== False`` kept on purpose: infer_detect's return type is not
    # visible here, and truthiness could differ (e.g. for None).
    label = "Not Watermarked" if result == False else "Watermarked"  # noqa: E712
    return label, temp_pdb_path


def _detect_uploaded_cif(file):
    """Change handler for the detector upload.

    ``change`` also fires when the upload is cleared. In that case, return empty
    outputs instead of dereferencing ``None`` and allocating a GPU for nothing.
    """
    if file is None:
        return "", None
    return is_watermarked(file)


with gr.Blocks(title="FoldMark", css=custom_css) as demo:
    with gr.Row():
        # Header logo; sized by the "#logo" rule in custom_css.
        gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)

    with gr.Tab("Structure Predictor (JSON Upload)"):
        # Create the upload component first so gr.Examples can reference it.
        json_upload = gr.File(label="Upload JSON", file_types=[".json"])
        gr.Examples(
            examples=[[EXAMPLE_PATH]],
            inputs=[json_upload],
            label="Click to use example JSON:",
            examples_per_page=1,
        )
        # NOTE(review): this textbox is not an input to any event below, so the
        # name entered here is currently unused.
        upload_name = gr.Textbox(label="Complex Name (optional)")
        upload_output = gr.JSON(label="Parsed Components")
        json_upload.upload(
            fn=_parse_uploaded_json,
            inputs=json_upload,
            outputs=upload_output,
        )

        # Shared prediction components
        with gr.Row():
            add_watermark.render()
            submit_btn = gr.Button("Predict Structure", variant="primary")
        with gr.Row():
            view3d = Molecule3D(label="3D Visualization(Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
        legend = gr.HTML("""
<div>
<strong>Color Legend:</strong><br>
- <span style="color:grey;">Gray: Unwatermarked Structure</span><br>
- <span style="color:cyan;">Cyan: Watermarked Structure</span>
</div>
""")
        with gr.Row():
            cif_file = gr.File(label="Download CIF File")
        with gr.Row():
            confidence_plot_image = gr.Image(label="Confidence Measures")
        input_collector = gr.JSON(visible=False)
        # Two-step submit: pack the inputs into one dict, then run prediction on it.
        submit_btn.click(
            fn=lambda j, w: {"json": j, "watermark": w},
            inputs=[json_upload, add_watermark],
            outputs=input_collector,
        ).then(
            fn=predict_structure,
            inputs=input_collector,
            outputs=[view3d, confidence_plot_image, cif_file],
        )
        gr.Markdown("""
The example of the uploaded json file for structure prediction.
<pre>
[{
"sequences": [
{
"proteinChain": {
"sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "CTAGGTAACATTACTCGCG",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "GCGAGTAATGTTAC",
"count": 2
}
},
{
"ligand": {
"ligand": "CCD_PCG",
"count": 2
}
}
],
"name": "7pzb"
}]
</pre>
""")

    with gr.Tab("Structure Predictor (Manual Input)"):
        with gr.Row():
            complex_name = gr.Textbox(label="Complex Name")
        # One fixed two-column (value, Count) table per entity kind.
        with gr.Accordion(label="Protein Chains", open=True):
            protein_chains = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1,
                col_count=(2, "fixed"),
                type="array",
            )
        with gr.Accordion(label="DNA Sequences (A T G C)", open=True):
            dna_sequences = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1,
                col_count=(2, "fixed"),
                type="array",
            )
        with gr.Accordion(label="RNA Sequences (A U G C)", open=True):
            rna_sequences = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1,
                col_count=(2, "fixed"),
                type="array",
            )
        with gr.Accordion(label="Ligands", open=True):
            ligands = gr.Dataframe(
                headers=["Ligand Type", "Count"],
                datatype=["str", "number"],
                row_count=1,
                col_count=(2, "fixed"),
                type="array",
            )
        manual_output = gr.JSON(label="Generated JSON")
        # Any edit to any manual input refreshes the JSON preview.
        manual_inputs = [complex_name, protein_chains, dna_sequences, rna_sequences, ligands]
        for widget in manual_inputs:
            widget.change(
                fn=update_json,
                inputs=manual_inputs,
                outputs=manual_output,
            )

        # Shared prediction components
        with gr.Row():
            add_watermark1.render()
            submit_btn = gr.Button("Predict Structure", variant="primary")
        with gr.Row():
            view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
        with gr.Row():
            cif_file = gr.File(label="Download CIF File")
        with gr.Row():
            confidence_plot_image = gr.Image(label="Confidence Measures")
        input_collector = gr.JSON(visible=False)
        # Two-step submit: pack the table contents into one dict, then predict.
        submit_btn.click(
            fn=lambda c, p, d, r, lig, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "rna_sequences": r, "ligands": lig}, "watermark": w},
            inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands, add_watermark1],
            outputs=input_collector,
        ).then(
            fn=predict_structure,
            inputs=input_collector,
            outputs=[view3d, confidence_plot_image, cif_file],
        )

    with gr.Tab("Watermark Detector"):
        # Extension filter fixed from "..cif", which matches no real file.
        cif_upload = gr.File(label="Upload .cif", file_types=[".cif"])
        with gr.Row():
            cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
        prediction_output = gr.Textbox(label="Prediction")
        cif_upload.change(_detect_uploaded_cif, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
        example_files = [
            "./examples/7r6r_watermarked.cif",
            "./examples/7pzb_unwatermarked.cif",
        ]
        gr.Examples(examples=example_files, inputs=cif_upload)


if __name__ == "__main__":
    demo.launch(share=True)