{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d16d95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2024 ByteDance and/or its affiliates.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81744ffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import csv\n",
    "from pathlib import Path\n",
    "from typing import Optional\n",
    "\n",
    "import pandas as pd\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm import tqdm\n",
    "\n",
    "from protenix.data.data_pipeline import DataPipeline\n",
    "from protenix.utils.file_io import dump_gzip_pickle"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d0c7b2e",
   "metadata": {},
   "source": [
    "## Preprocess an mmCIF file into a bioassembly pickle\n",
    "\n",
    "The next cells parse `./dataset/7pzb.cif` with `DataPipeline`, save the result as a gzipped pickle, build dataset parameters from the Protenix configs, and render a previously saved prediction. The interactive demo starting at *Input - information about input complex* does not depend on them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02412ab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_DIR = \"./dataset\"\n",
    "dataset = \"Distillation\"\n",
    "\n",
    "sample_indices_list, bioassembly_dict = DataPipeline.get_data_from_mmcif(\n",
    "    mmcif=f\"{DATASET_DIR}/7pzb.cif\", pdb_cluster_file=None, dataset=dataset\n",
    ")\n",
    "# Show only the keys; the full dict can be very large.\n",
    "print(list(bioassembly_dict.keys()))\n",
    "\n",
    "pdb_id = bioassembly_dict[\"pdb_id\"]\n",
    "# Save the processed bioassembly next to the input mmCIF.\n",
    "dump_gzip_pickle(bioassembly_dict, f\"{DATASET_DIR}/{pdb_id}.pkl.gz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ff18a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs.configs_base import configs as configs_base\n",
    "from configs.configs_data import data_configs\n",
    "from configs.configs_inference import inference_configs\n",
    "from protenix.config import parse_configs\n",
    "\n",
    "# Depends on `pdb_id` and `DATASET_DIR` from the previous cell.\n",
    "arg_str = \"--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 \"\n",
    "configs = {**configs_base, \"data\": data_configs, **inference_configs}\n",
    "configs = parse_configs(\n",
    "    configs=configs,\n",
    "    arg_str=arg_str,\n",
    "    fill_required_with_null=True,\n",
    ")\n",
    "data_config = configs.data\n",
    "config_dict = data_config[\"weightedPDB_before2109_wopb_nometalc_0925\"].to_dict()\n",
    "\n",
    "# NOTE(review): `params` is not used anywhere else in this notebook.\n",
    "params = {\n",
    "    \"name\": pdb_id,\n",
    "    **config_dict[\"base_info\"],\n",
    "    \"cropping_configs\": config_dict[\"cropping_configs\"],\n",
    "    \"error_dir\": DATASET_DIR,\n",
    "    \"msa_featurizer\": None,\n",
    "    \"template_featurizer\": None,\n",
    "    \"lig_atom_rename\": False,\n",
    "    \"shuffle_mols\": False,\n",
    "    \"shuffle_sym_ids\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "768a767a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "from protenix.web_service.prediction_visualization import PredictionLoader, plot_3d\n",
    "\n",
    "# NOTE(review): hard-coded path to an earlier run; point it at your own dump_dir/seed.\n",
    "prediction_fpath = \"../output/7pzb/seed_112/predictions/\"\n",
    "\n",
    "pred_loader = PredictionLoader(prediction_fpath)\n",
    "html_content = plot_3d(pred_id=0, pred_loader=pred_loader)\n",
    "display(html_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "104f7a6b-e998-46d5-a8f5-0528d4c4abc0",
   "metadata": {
    "metadata": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from copy import deepcopy\n",
    "\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "\n",
    "from protenix.web_service.colab_request_parser import RequestParser\n",
    "from protenix.web_service.viewer import ProtenixInputViewer\n",
    "from protenix.web_service.prediction_visualization import (\n",
    "    plot_confidence_measures_from_pred,\n",
    "    plot_contact_maps_from_pred,\n",
    "    PredictionLoader,\n",
    "    plot_3d,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e9a6c1f",
   "metadata": {},
   "source": [
    "# Input - information about input complex\n",
    "\n",
    "## Add global information\n",
    "- **name**: System name.\n",
    "- **use_msa**: Whether to use MSA information during inference.\n",
    "- **atom_confidence**: Whether to return the confidence of each atom during inference.\n",
    "\n",
    "## Add Entity\n",
    "For each entity, select its type, set the number of copies (**Copies**), and provide its sequence, SMILES string, or CCD code.\n",
    "- Click **+Dna/Rna/Protein** to add a DNA/RNA/protein sequence.\n",
    "- Click **+Ligand(SMILES)** to add a ligand given as a SMILES string.\n",
    "- Click **+Ligand/Ion CCD** to add a ligand or ion given as a CCD code.\n",
    "\n",
    "### Add Modification for Dna/Rna/Protein\n",
    "Each Dna/Rna/Protein entity can have multiple modifications. 
Modification information includes the modification type and its position.\n",
    "- Click **+modification** to add a modification.\n",
    "\n",
    "## Add covalent bond\n",
    "Each covalent bond specifies the left/right entity, the left/right position, and the left/right atom.\n",
    "- Click **+convalent_bond** to add a covalent bond.\n",
    "\n",
    "## Add model inference information\n",
    "- **N_sample**: The number of samples the diffusion model generates per seed.\n",
    "- **N_step**: The number of diffusion steps.\n",
    "- **N_cycle**: The number of recycling cycles of the model backbone during inference.\n",
    "- **seeds**: Random seeds.\n",
    "- **version**: v1/v2/v3/v4/v5. Currently only v1 is supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e69bdd",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Build the request interactively. Ticking `atom_confidence` saves per-atom confidence\n",
    "# metrics (such as atom_plddt); the pLDDT colouring at the end presumably needs them.\n",
    "viewer = ProtenixInputViewer()\n",
    "display(viewer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce116461",
   "metadata": {},
   "source": [
    "# Input - information about results saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4073bb4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = \"./demo\"\n",
    "if os.path.exists(save_dir):\n",
    "    print(\n",
    "        f\"[WARNING]: Results will be saved to an existing path:\\n{save_dir}.\\n\"\n",
    "        \"Please verify that this is the intended location before proceeding.\"\n",
    "    )\n",
    "os.makedirs(save_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "778fb57f",
   "metadata": {},
   "source": [
    "# Inference code body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bce87c0-dc4f-4692-801c-ff1c4453b062",
   "metadata": {
    "metadata": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "# get_result() asserts that `name` is non-empty, so fill it in the viewer first.\n",
    "input_dict = viewer.get_result()\n",
    "output = json.dumps(input_dict, ensure_ascii=False, indent=4)\n",
    "input_dict_path = os.path.join(save_dir, \"input.json\")\n",
    "# ensure_ascii=False can emit non-ASCII text, so write UTF-8 explicitly.\n",
    "with open(input_dict_path, \"w\", encoding=\"utf-8\") as fid:\n",
    "    fid.write(output)\n",
    "\n",
    "print(f\"Input information saved to:\\n{input_dict_path}\")\n",
    "print(\"Please verify the input information below:\")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e499326-61ab-4bd1-9a83-9c1f82f0fc56",
   "metadata": {
    "metadata": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run inference for the saved request; outputs are written under save_dir.\n",
    "parser = RequestParser(request_json_path=input_dict_path, request_dir=save_dir)\n",
    "parser.launch()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8df5ec4a-6bc4-4520-b45a-0f3964b01ad0",
   "metadata": {},
   "source": [
    "# Visualize predicted structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "211d882e",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "seed = 1  # which seed do you want to visualize? It must be one of the seeds used above.\n",
    "\n",
    "job_name = input_dict[\"name\"]\n",
    "pred_base_path = f\"{save_dir}/{job_name}/seed_{seed}/predictions\"\n",
    "if not os.path.isdir(pred_base_path):\n",
    "    raise FileNotFoundError(\n",
    "        f\"No predictions found at {pred_base_path}; check that `seed` matches a seed used for inference.\"\n",
    "    )\n",
    "pred_loader = PredictionLoader(pred_base_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d205c54",
   "metadata": {},
   "source": [
    "### Visualizing contact maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f423d43d-f749-4e19-8aa6-9670a441c7ce",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Contact maps of all loaded predictions, computed on a chosen representative atom.\n",
    "rep_atom = widgets.Text(value=\"CA\", description=\"Rep Atom\", disabled=False)\n",
    "threshold = widgets.FloatSlider(\n",
    "    value=10,\n",
    "    min=1,\n",
    "    max=20.0,\n",
    "    step=0.1,\n",
    "    description=\"Threshold\",\n",
    "    continuous_update=False,\n",
    ")\n",
    "\n",
    "\n",
    "def plot_contact_map_wrapper(threshold: float, rep_atom: str):\n",
    "    \"\"\"Redraw the contact maps for the current widget values.\"\"\"\n",
    "    plot_contact_maps_from_pred(\n",
    "        preds=pred_loader.preds,\n",
    "        fnames=pred_loader.fnames,\n",
    "        threshold=threshold,\n",
    "        rep_atom=rep_atom,\n",
    "    )\n",
    "\n",
    "\n",
    "interactive_plot = widgets.interactive(\n",
    "    plot_contact_map_wrapper, threshold=threshold, rep_atom=rep_atom\n",
    ")\n",
    "display(interactive_plot)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa95ad95",
   "metadata": {},
   "source": [
    "### Visualizing confidence metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": 
"237ddde4-950d-446d-b9b1-45436683bbbf", "metadata": { "metadata": {} }, "outputs": [], "source": [ "# Visualizing confidence metrics\n", "if not pred_loader.full_confidence_data:\n", " print(\"All-atom confidence metrics not saved.\")\n", "else:\n", " bool_option = widgets.Checkbox(value=True, description=\"Print global metrics\")\n", "\n", " def plot_wrapper(show_global_confidence: bool):\n", " plot_confidence_measures_from_pred(\n", " full_confidence=pred_loader.full_confidence_data,\n", " summary_confidence=pred_loader.summary_confidence_data,\n", " fnames=pred_loader.fnames,\n", " show_global_confidence=show_global_confidence,\n", " )\n", "\n", " interactive_plot = widgets.interactive(\n", " plot_wrapper, show_global_confidence=bool_option\n", " )\n", "\n", " display(interactive_plot)" ] }, { "cell_type": "markdown", "id": "55e9bfde", "metadata": {}, "source": [ "### Visualizing 3d structure" ] }, { "cell_type": "code", "execution_count": null, "id": "f4b6184d-b3bb-4876-817d-284535028fa1", "metadata": { "metadata": {}, "tags": [] }, "outputs": [], "source": [ "# pred_id: the i-th prediction from the model\n", "pred_id = 0\n", "plot_3d(\n", " pred_id=pred_id,\n", " pred_loader=pred_loader,\n", " show_sidechains=False,\n", " show_mainchains=False,\n", " color=\"rainbow\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "53adbbbb-929d-4f23-80c7-3d1c2eb5d783", "metadata": { "metadata": {}, "tags": [] }, "outputs": [], "source": [ "if not pred_loader.full_confidence_data:\n", " plot_3d_func = lambda *args, **kwargs: None\n", " print(\"All-atom confidence metrics not saved. 
Hence plotting by pLDDT disabled.\")\n", "else:\n", " plot_3d_func = deepcopy(plot_3d)\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "id": "231fc13c", "metadata": {}, "outputs": [], "source": [ "plot_3d_func(\n", " pred_id=pred_id,\n", " pred_loader=pred_loader,\n", " show_sidechains=False,\n", " show_mainchains=False,\n", " color=\"pLDDT\",\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "pro_new", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }