# FoldMark / runner/inference.py
# Repository-viewer residue kept for provenance: "Zaixi's picture", commit message "worker=0", commit fe8f637
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import traceback
import urllib.request
from contextlib import nullcontext
from os.path import exists as opexists
from os.path import join as opjoin
from typing import Any, Mapping, Optional

import gdown
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from runner.dumper import DataDumper
from protenix.config import parse_configs, parse_sys_args
from protenix.data.infer_data_pipeline import get_inference_dataloader
from protenix.model.protenix_edit import Protenix
from protenix.utils.distributed import DIST_WRAPPER
from protenix.utils.seed import seed_everything
from protenix.utils.torch_utils import to_device
from protenix.data.dataset import BaseSingleDataset
from protenix.web_service.dependency_url import URL
# Module-level logger named after this module; used by the runner and the
# infer_* helpers below.
logger = logging.getLogger(__name__)
def get_recovery(pred_code, gt_code):
    """Return the mean bit-recovery rate of predicted watermark logits.

    ``pred_code`` is clamped to [-10, 10], passed through a sigmoid and
    thresholded at 0.5. The thresholded bits are compared element-wise with
    ``gt_code``; the match rate is averaged over the last dimension, floored
    at 1e-6, and finally averaged over everything that remains.
    """
    floor = 1e-6
    bounded_logits = torch.clamp(pred_code, min=-10, max=10)
    predicted_bits = (torch.sigmoid(bounded_logits) > 0.5).float()
    matches = (predicted_bits == gt_code.float()).float()
    # The floor keeps a fully mismatched code from contributing exactly zero.
    per_code_recovery = matches.mean(dim=-1).clamp(min=floor)
    return per_code_recovery.mean()
class InferenceRunner(object):
    """Owns the Protenix model and the result dumper used for inference.

    Construction runs, in order: environment setup, output-directory
    creation, model instantiation, checkpoint loading and dumper creation.
    """

    # Maps the ``configs.dtype`` string to the autocast dtype.
    _EVAL_PRECISIONS = {
        "fp32": torch.float32,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }

    def __init__(self, configs: Any) -> None:
        self.configs = configs
        self.init_env()
        self.init_basics()
        self.init_model()
        self.load_checkpoint()
        self.init_dumper(
            need_atom_confidence=configs.need_atom_confidence,
            sorted_by_ranking_score=configs.sorted_by_ranking_score,
        )

    def init_env(self) -> None:
        """Pick the device, start the process group if needed, and log kernel notes."""
        self.print(
            f"Distributed environment: world size: {DIST_WRAPPER.world_size}, "
            + f"global rank: {DIST_WRAPPER.rank}, local rank: {DIST_WRAPPER.local_rank}"
        )
        self.use_cuda = torch.cuda.device_count() > 0
        if self.use_cuda:
            self.device = torch.device("cuda:{}".format(DIST_WRAPPER.local_rank))
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
            devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
            logging.info(
                f"LOCAL_RANK: {DIST_WRAPPER.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]"
            )
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
        if DIST_WRAPPER.world_size > 1:
            dist.init_process_group(backend="nccl")
        if self.configs.use_deepspeed_evo_attention:
            env = os.getenv("CUTLASS_PATH", None)
            self.print(f"env: {env}")
            assert (
                env is not None
            ), "if use ds4sci, set `CUTLASS_PATH` env as https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/"
            # Still guarded: the assert above is stripped under ``python -O``.
            if env is not None:
                logging.info(
                    "The kernels will be compiled when DS4Sci_EvoformerAttention is called for the first time."
                )
        use_fastlayernorm = os.getenv("LAYERNORM_TYPE", None)
        if use_fastlayernorm == "fast_layernorm":
            logging.info(
                "The kernels will be compiled when fast_layernorm is called for the first time."
            )
        logging.info("Finished init ENV.")

    def init_basics(self) -> None:
        """Create ``configs.dump_dir`` and its ``ERR`` sub-directory."""
        self.dump_dir = self.configs.dump_dir
        self.error_dir = opjoin(self.dump_dir, "ERR")
        os.makedirs(self.dump_dir, exist_ok=True)
        os.makedirs(self.error_dir, exist_ok=True)

    def init_model(self) -> None:
        """Instantiate the Protenix model on the selected device."""
        self.model = Protenix(self.configs).to(self.device)

    def load_checkpoint(self) -> None:
        """Load ``configs.load_checkpoint_path`` into the model and switch to eval mode.

        Raises:
            FileNotFoundError: if the checkpoint path does not exist.
        """
        checkpoint_path = self.configs.load_checkpoint_path
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Given checkpoint path not exist [{checkpoint_path}]"
            )
        self.print(
            f"Loading from {checkpoint_path}, strict: {self.configs.load_strict}"
        )
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        state_dict = checkpoint["model"]
        # An empty state dict yields "" here instead of raising IndexError.
        sample_key = next(iter(state_dict), "")
        self.print(f"Sampled key: {sample_key}")
        if sample_key.startswith("module."):  # DDP checkpoint has module. prefix
            # removeprefix leaves keys without the prefix untouched instead of
            # chopping their first seven characters.
            state_dict = {
                k.removeprefix("module."): v for k, v in state_dict.items()
            }
        self.model.load_state_dict(
            state_dict=state_dict,
            strict=self.configs.load_strict,
        )
        self.model.eval()
        self.print("Finish loading checkpoint.")

    def init_dumper(
        self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True
    ):
        """Create the DataDumper that writes results under ``self.dump_dir``."""
        self.dumper = DataDumper(
            base_dir=self.dump_dir,
            need_atom_confidence=need_atom_confidence,
            sorted_by_ranking_score=sorted_by_ranking_score,
        )

    def _autocast_context(self):
        """Return a CUDA autocast context for ``configs.dtype``, or a no-op context without CUDA.

        Raises:
            KeyError: if ``configs.dtype`` is not one of "fp32", "bf16", "fp16".
        """
        eval_precision = self._EVAL_PRECISIONS[self.configs.dtype]
        if torch.cuda.is_available():
            return torch.autocast(device_type="cuda", dtype=eval_precision)
        return nullcontext()

    # Adapted from runner.train.Trainer.evaluate
    @torch.no_grad()
    def predict(
        self, data: Mapping[str, Mapping[str, Any]], watermark=False
    ) -> tuple[dict[str, torch.Tensor], Any]:
        """Run the model in inference mode on ``data["input_feature_dict"]``.

        Args:
            data: mapping that provides ``"input_feature_dict"``.
            watermark: forwarded to the model as its ``watermark`` argument.

        Returns:
            ``(prediction, label_dict)``: the first two values returned by the model.
        """
        amp_context = self._autocast_context()
        data = to_device(data, self.device)
        with amp_context:
            prediction, label_dict, _ = self.model(
                input_feature_dict=data["input_feature_dict"],
                label_full_dict=None,
                label_dict=None,
                mode="inference",
                watermark=watermark,
            )
        return prediction, label_dict

    # Adapted from runner.train.Trainer.evaluate
    @torch.no_grad()
    def detect(
        self, data: Mapping[str, Mapping[str, Any]]
    ) -> tuple[dict[str, torch.Tensor], Any]:
        """Run the model in inference mode with ``detect=True``.

        Args:
            data: mapping that provides ``"input_feature_dict"``,
                ``"label_dict"`` and ``"label_full_dict"``.

        Returns:
            ``(prediction, label_dict)``: the first two values returned by the model.
        """
        amp_context = self._autocast_context()
        data = to_device(data, self.device)
        with amp_context:
            prediction, label_dict, _ = self.model(
                input_feature_dict=data["input_feature_dict"],
                label_dict=data["label_dict"],
                label_full_dict=data["label_full_dict"],
                mode="inference",
                detect=True,
            )
        return prediction, label_dict

    def print(self, msg: str):
        """Log ``msg`` at INFO level on global rank 0 only."""
        if DIST_WRAPPER.rank == 0:
            logger.info(msg)

    def update_model_configs(self, new_configs: Any) -> None:
        """Replace the ``configs`` attribute of the wrapped model."""
        self.model.configs = new_configs
def download_infercence_cache(configs: Any = None, model_version: str = "v0.2.0") -> None:
    """Download the CCD cache files and the model checkpoint if they are missing.

    Files are fetched from Google Drive with ``gdown`` into paths relative to
    the current working directory; files that already exist are skipped.

    Args:
        configs: accepted so that callers passing the parsed configs keep
            working; currently ignored.
        model_version: accepted for call-site compatibility; currently
            ignored (the checkpoint file id is fixed below).
    """
    code_directory = './'
    data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache")
    os.makedirs(data_cache_dir, exist_ok=True)

    # (destination path, Google Drive file id), downloaded in this order.
    downloads = (
        (
            './release_data/ccd_cache/components.v20240608.cif',
            '1e8wxpuEB-0xL_3dlMfZCFo6cL5oSHSUK',
        ),
        (
            './release_data/ccd_cache/components.v20240608.cif.rdkit_mol.pkl',
            '1R9d678aBfQwTd0Rh15doRmW-fETNdeWf',
        ),
        (
            './checkpoint.pt',
            '17zBIRed3xZM8ux0bq2hpf1oFC75Y7OEw',
        ),
    )
    for output_file, file_id in downloads:
        if os.path.exists(output_file):
            continue
        download_url = f'https://drive.google.com/uc?id={file_id}'
        gdown.download(download_url, output_file, quiet=False)
# checkpoint_path = configs.load_checkpoint_path
# if not opexists(checkpoint_path):
# checkpoint_path = os.path.join(
# code_directory, f"release_data/checkpoint/model_{model_version}.pt"
# )
# os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
# tos_url = URL[f"model_{model_version}"]
# logger.info(f"Downloading model checkpoint from\n {tos_url}...")
# urllib.request.urlretrieve(tos_url, checkpoint_path)
# try:
# ckpt = torch.load(checkpoint_path)
# del ckpt
# except:
# os.remove(checkpoint_path)
# raise RuntimeError(
# "Download model checkpoint failed, please download by yourself with "
# f"wget {tos_url} -O {checkpoint_path}"
# )
# configs.load_checkpoint_path = checkpoint_path
def update_inference_configs(configs: Any, N_token: int):
    """Set the ``skip_amp`` flags on ``configs`` according to ``N_token``.

    The flags are switched off step by step as the input grows, because the
    default configuration might run out of memory for very large inputs (the
    original note mentions N_token above 3000 on an A100 80G GPU):

    * N_token > 3840: both flags False
    * 2560 < N_token <= 3840: ``confidence_head`` False, ``sample_diffusion`` True
    * otherwise: both flags True

    ``configs`` is modified in place and also returned.
    """
    skip_amp = configs.skip_amp
    skip_amp.confidence_head = not (N_token > 2560)
    skip_amp.sample_diffusion = not (N_token > 3840)
    return configs
def infer_detect(runner: InferenceRunner, configs: Any) -> Optional[bool]:
    """Run watermark detection on the structures found under ``configs.subdir``.

    A ``BaseSingleDataset`` is built from the
    ``"recentPDB_1536_sample384_0925"`` data config, with its bioassembly
    directory and index file (``output.csv``) pointed at ``configs.subdir``.
    Every batch is passed to ``runner.detect``.

    Args:
        runner: runner that provides ``detect`` and ``error_dir``.
        configs: parsed configs; reads ``data``, ``subdir`` and
            ``process_success``.

    Returns:
        ``None`` if the dataset or dataloader could not be built (the error is
        logged and appended to ``error.txt`` in ``runner.error_dir``).
        Otherwise ``predicted_class and configs.process_success``, where
        ``predicted_class`` is the thresholded verdict of the last batch, or
        ``False`` when the dataloader yields no batch. When the verdict is
        positive the returned value is ``configs.process_success`` itself
        (presumably a bool; confirm against the caller).
    """
    try:
        data_config = configs.data
        config_dict = data_config["recentPDB_1536_sample384_0925"].to_dict()
        base_info = config_dict['base_info']
        base_info['bioassembly_dict_dir'] = configs.subdir
        base_info['indices_fpath'] = os.path.join(configs.subdir, 'output.csv')
        base_info['pdb_list'] = ""
        params = {
            "name": 'detection_data',
            **base_info,
            "cropping_configs": config_dict["cropping_configs"],
            "error_dir": configs.subdir,
            "msa_featurizer": None,
            "template_featurizer": None,
            "lig_atom_rename": False,
            "shuffle_mols": False,
            "shuffle_sym_ids": False,
        }
        test_dataset = BaseSingleDataset(**params)
        dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            sampler=None,
            # batch_size is 1, so hand the single sample through unchanged.
            collate_fn=lambda batch: batch[0],
        )
    except Exception as e:
        error_message = f"{e}:\n{traceback.format_exc()}"
        logger.info(error_message)
        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
            f.write(error_message)
        return None

    # Default verdict so that an empty dataloader does not leave the name
    # unbound at the return statement.
    predicted_class = False
    for batch in dataloader:
        prediction, _ = runner.detect(batch)
        # Clamp the logit before the sigmoid, then threshold at 0.5.
        pred_code = torch.clamp(prediction['watermark'], min=-10, max=10)
        predicted_class = torch.sigmoid(pred_code).item() > 0.5
    return predicted_class and configs.process_success
def infer_predict(runner: InferenceRunner, configs: Any) -> Optional[tuple[Any, Any]]:
    """Predict structures for every sample in ``configs.input_json_path``.

    For each seed in ``configs.seeds`` the whole dataloader is processed.
    Each sample is predicted with ``runner.predict`` and written out with
    ``runner.dumper.dump``. A failure on one sample is logged and appended to
    ``<sample_name>.txt`` in ``runner.error_dir``, after which processing
    continues with the next sample.

    Args:
        runner: runner that provides ``predict``, ``dumper``,
            ``update_model_configs`` and ``error_dir``.
        configs: parsed configs; reads ``input_json_path``, ``seeds``,
            ``deterministic``, ``watermark``, ``saved_path`` and ``dump_dir``.

    Returns:
        ``None`` if the dataloader could not be created (the error is logged
        and appended to ``error.txt``). Otherwise ``(sample_name, seed)`` of
        the last sample read and the last seed used; either entry is ``None``
        if no sample / no seed was processed.
    """
    # Data
    logger.info(f"Loading data from\n{configs.input_json_path}")
    try:
        dataloader = get_inference_dataloader(configs=configs)
    except Exception as e:
        error_message = f"{e}:\n{traceback.format_exc()}"
        logger.info(error_message)
        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
            f.write(error_message)
        return None

    num_data = len(dataloader.dataset)
    # Pre-bind the returned names so empty seeds / an empty dataloader do not
    # raise UnboundLocalError at the return statement.
    sample_name = None
    seed = None
    for seed in configs.seeds:
        seed_everything(seed=seed, deterministic=configs.deterministic)
        for batch in dataloader:
            # Name used for error reporting of THIS batch only; a failure
            # before the name is read must not reuse the previous sample's name.
            batch_name = "unknown"
            try:
                data, atom_array, data_error_message = batch[0]
                sample_name = batch_name = data["sample_name"]
                if len(data_error_message) > 0:
                    logger.info(data_error_message)
                    with open(opjoin(runner.error_dir, f"{batch_name}.txt"), "a") as f:
                        f.write(data_error_message)
                    continue
                logger.info(
                    (
                        f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {batch_name}: "
                        f"N_asym {data['N_asym'].item()}, N_token {data['N_token'].item()}, "
                        f"N_atom {data['N_atom'].item()}, N_msa {data['N_msa'].item()}"
                    )
                )
                new_configs = update_inference_configs(configs, data["N_token"].item())
                runner.update_model_configs(new_configs)
                prediction, _ = runner.predict(data, configs.watermark)
                runner.dumper.dump(
                    dataset_name="",
                    pdb_id=batch_name,
                    seed=seed,
                    pred_dict=prediction,
                    atom_array=atom_array,
                    entity_poly_type=data["entity_poly_type"],
                    saved_path=configs.saved_path,
                )
                logger.info(
                    f"[Rank {DIST_WRAPPER.rank}] {batch_name} succeeded.\n"
                    f"Results saved to {configs.dump_dir}"
                )
                torch.cuda.empty_cache()
            except Exception as e:
                error_message = f"[Rank {DIST_WRAPPER.rank}]{batch_name} {e}:\n{traceback.format_exc()}"
                logger.info(error_message)
                # Save error info
                with open(opjoin(runner.error_dir, f"{batch_name}.txt"), "a") as f:
                    f.write(error_message)
                torch.cuda.empty_cache()
    return sample_name, seed
def main(configs: Any) -> None:
    """Build an ``InferenceRunner`` from ``configs`` and run prediction with it."""
    infer_predict(InferenceRunner(configs), configs)
def run() -> None:
    """Command-line entry point: set up logging, parse configs, fetch data, run inference."""
    LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
    # No ``filemode``: it only takes effect together with ``filename``, and
    # logging here goes to the default stream handler.
    logging.basicConfig(
        format=LOG_FORMAT,
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # NOTE(review): the variable name is spelled with three T's; kept as-is
    # because launch scripts may already set it under this name.
    configs_base["use_deepspeed_evo_attention"] = (
        os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", "") == "true"
    )
    configs = {**configs_base, "data": data_configs, **inference_configs}
    configs = parse_configs(
        configs=configs,
        arg_str=parse_sys_args(),
        fill_required_with_null=True,
    )
    # Called without arguments: the downloader uses fixed paths and file ids.
    download_infercence_cache()
    main(configs)


if __name__ == "__main__":
    run()