# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import traceback
import urllib.request
from contextlib import nullcontext
from os.path import exists as opexists
from os.path import join as opjoin
from typing import Any, Mapping, Optional

import gdown
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from runner.dumper import DataDumper
from protenix.config import parse_configs, parse_sys_args
from protenix.data.infer_data_pipeline import get_inference_dataloader
from protenix.model.protenix_edit import Protenix
from protenix.utils.distributed import DIST_WRAPPER
from protenix.utils.seed import seed_everything
from protenix.utils.torch_utils import to_device
from protenix.data.dataset import BaseSingleDataset
from protenix.web_service.dependency_url import URL

logger = logging.getLogger(__name__)

# Maps the ``configs.dtype`` string to the autocast dtype used at eval time.
_EVAL_PRECISION = {
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
}

# (Google Drive file id, local output path) for every file that
# ``download_infercence_cache`` fetches when it is missing locally.
_INFERENCE_CACHE_FILES = (
    (
        "1e8wxpuEB-0xL_3dlMfZCFo6cL5oSHSUK",
        "./release_data/ccd_cache/components.v20240608.cif",
    ),
    (
        "1R9d678aBfQwTd0Rh15doRmW-fETNdeWf",
        "./release_data/ccd_cache/components.v20240608.cif.rdkit_mol.pkl",
    ),
    (
        "17zBIRed3xZM8ux0bq2hpf1oFC75Y7OEw",
        "./checkpoint.pt",
    ),
)


def get_recovery(pred_code, gt_code):
    """Return the mean bit-recovery rate of predicted codes against ground truth.

    ``pred_code`` is treated as logits: each entry is clamped to [-10, 10],
    thresholded with ``sigmoid(x) > 0.5`` and compared element-wise with
    ``gt_code`` (cast to float). The match rate along the last dimension is
    floored at ``1e-6`` and then averaged over all remaining elements.

    NOTE(review): ``gt_code`` is presumably a 0/1 tensor broadcast-compatible
    with ``pred_code`` -- confirm against the caller.
    """
    pred_code = torch.clamp(pred_code, min=-10, max=10)
    epsilon = 1e-6
    predicted_classes = (torch.sigmoid(pred_code) > 0.5).float()
    # Floor each per-row recovery at epsilon so it is never exactly zero.
    # (There is no division here; the floor only affects reported values.)
    recovery = (
        (predicted_classes == gt_code.float()).float().mean(dim=-1)
    ).clamp(min=epsilon)
    return recovery.mean()


class InferenceRunner(object):
    """Build a Protenix model from ``configs``, load its checkpoint and run it.

    Construction selects the device / distributed environment, creates the
    dump and error directories, instantiates the model, loads the checkpoint
    from ``configs.load_checkpoint_path`` and creates a ``DataDumper``.
    """

    def __init__(self, configs: Any) -> None:
        self.configs = configs
        self.init_env()
        self.init_basics()
        self.init_model()
        self.load_checkpoint()
        self.init_dumper(
            need_atom_confidence=configs.need_atom_confidence,
            sorted_by_ranking_score=configs.sorted_by_ranking_score,
        )

    def init_env(self) -> None:
        """Pick the device, init the process group and log kernel settings.

        Sets ``self.use_cuda`` and ``self.device`` (``cuda:<local_rank>`` when
        any GPU is visible, otherwise CPU). When ``world_size > 1`` a default
        process group is created.
        """
        self.print(
            f"Distributed environment: world size: {DIST_WRAPPER.world_size}, "
            + f"global rank: {DIST_WRAPPER.rank}, local rank: {DIST_WRAPPER.local_rank}"
        )
        self.use_cuda = torch.cuda.device_count() > 0
        if self.use_cuda:
            self.device = torch.device("cuda:{}".format(DIST_WRAPPER.local_rank))
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
            devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
            logging.info(
                f"LOCAL_RANK: {DIST_WRAPPER.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]"
            )
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
        if DIST_WRAPPER.world_size > 1:
            # NCCL only works with CUDA tensors; use Gloo on CPU-only hosts.
            dist.init_process_group(backend="nccl" if self.use_cuda else "gloo")
        if self.configs.use_deepspeed_evo_attention:
            env = os.getenv("CUTLASS_PATH", None)
            self.print(f"env: {env}")
            assert (
                env is not None
            ), "if use ds4sci, set `CUTLASS_PATH` env as https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/"
            if env is not None:
                logging.info(
                    "The kernels will be compiled when DS4Sci_EvoformerAttention is called for the first time."
                )
        use_fastlayernorm = os.getenv("LAYERNORM_TYPE", None)
        if use_fastlayernorm == "fast_layernorm":
            logging.info(
                "The kernels will be compiled when fast_layernorm is called for the first time."
            )

        logging.info("Finished init ENV.")

    def init_basics(self) -> None:
        """Create ``configs.dump_dir`` and its ``ERR`` sub-directory."""
        self.dump_dir = self.configs.dump_dir
        self.error_dir = opjoin(self.dump_dir, "ERR")
        os.makedirs(self.dump_dir, exist_ok=True)
        os.makedirs(self.error_dir, exist_ok=True)

    def init_model(self) -> None:
        """Instantiate the Protenix model on ``self.device``."""
        self.model = Protenix(self.configs).to(self.device)

    def load_checkpoint(self) -> None:
        """Load ``checkpoint["model"]`` into the model and switch it to eval mode.

        Raises:
            FileNotFoundError: if ``configs.load_checkpoint_path`` does not exist.
        """
        checkpoint_path = self.configs.load_checkpoint_path
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Given checkpoint path not exist [{checkpoint_path}]"
            )
        self.print(
            f"Loading from {checkpoint_path}, strict: {self.configs.load_strict}"
        )
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        state_dict = checkpoint["model"]
        sample_key = next(iter(state_dict), "")
        self.print(f"Sampled key: {sample_key}")
        if sample_key.startswith("module."):
            # DDP checkpoints prefix parameter names with "module.". Strip it
            # only where present so unprefixed keys are never truncated.
            state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
        self.model.load_state_dict(
            state_dict=state_dict,
            strict=self.configs.load_strict,
        )
        self.model.eval()
        self.print("Finish loading checkpoint.")

    def init_dumper(
        self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True
    ):
        """Create the ``DataDumper`` that writes predictions under ``dump_dir``."""
        self.dumper = DataDumper(
            base_dir=self.dump_dir,
            need_atom_confidence=need_atom_confidence,
            sorted_by_ranking_score=sorted_by_ranking_score,
        )

    def _amp_context(self):
        """Return a CUDA autocast context at ``configs.dtype``, or a no-op on CPU.

        Raises:
            KeyError: if ``configs.dtype`` is not one of fp32/bf16/fp16.
        """
        eval_precision = _EVAL_PRECISION[self.configs.dtype]
        if torch.cuda.is_available():
            return torch.autocast(device_type="cuda", dtype=eval_precision)
        return nullcontext()

    # Adapted from runner.train.Trainer.evaluate
    @torch.no_grad()
    def predict(
        self, data: Mapping[str, Mapping[str, Any]], watermark=False
    ) -> tuple[dict[str, torch.Tensor], Any]:
        """Run the model in inference mode on one sample (no labels).

        Args:
            data: Sample holding at least ``"input_feature_dict"``; it is moved
                to ``self.device`` before the forward pass.
            watermark: Forwarded unchanged to the model's ``watermark`` kwarg.

        Returns:
            ``(prediction, label_dict)`` from the model forward; its third
            output is discarded.
        """
        amp_context = self._amp_context()
        data = to_device(data, self.device)
        with amp_context:
            prediction, label_dict, _ = self.model(
                input_feature_dict=data["input_feature_dict"],
                label_full_dict=None,
                label_dict=None,
                mode="inference",
                watermark=watermark,
            )
        return prediction, label_dict

    # Adapted from runner.train.Trainer.evaluate
    @torch.no_grad()
    def detect(
        self, data: Mapping[str, Mapping[str, Any]]
    ) -> tuple[dict[str, torch.Tensor], Any]:
        """Run the model in inference mode with ``detect=True`` on one sample.

        Args:
            data: Sample holding ``"input_feature_dict"``, ``"label_dict"`` and
                ``"label_full_dict"``; moved to ``self.device`` first.

        Returns:
            ``(prediction, label_dict)`` from the model forward; its third
            output is discarded.
        """
        amp_context = self._amp_context()
        data = to_device(data, self.device)
        with amp_context:
            prediction, label_dict, _ = self.model(
                input_feature_dict=data["input_feature_dict"],
                label_dict=data["label_dict"],
                label_full_dict=data["label_full_dict"],
                mode="inference",
                detect=True,
            )
        return prediction, label_dict

    def print(self, msg: str):
        """Log ``msg`` at INFO level on rank 0 only."""
        if DIST_WRAPPER.rank == 0:
            logger.info(msg)

    def update_model_configs(self, new_configs: Any) -> None:
        """Replace the configs object held by the model."""
        self.model.configs = new_configs


def download_infercence_cache(
    configs: Any = None, model_version: Optional[str] = None
) -> None:
    """Download the CCD cache files and the model checkpoint if missing.

    Files are fetched from Google Drive with ``gdown`` into
    ``./release_data/ccd_cache/`` and ``./checkpoint.pt``. Files that already
    exist are left untouched.

    Args:
        configs: Accepted for call compatibility; currently unused.
        model_version: Accepted for call compatibility; currently unused.
    """
    code_directory = "./"
    data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache")
    os.makedirs(data_cache_dir, exist_ok=True)

    for file_id, output_file in _INFERENCE_CACHE_FILES:
        if not os.path.exists(output_file):
            download_url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(download_url, output_file, quiet=False)


def update_inference_configs(configs: Any, N_token: int):
    """Set the ``configs.skip_amp`` flags according to the number of tokens.

    - ``N_token > 3840``: both ``confidence_head`` and ``sample_diffusion`` are
      set to ``False``.
    - ``2560 < N_token <= 3840``: only ``confidence_head`` is set to ``False``.
    - otherwise: both are set to ``True``.

    Large inputs can run out of memory with the default settings even on an
    80 GB A100, which is why fewer modules skip AMP as ``N_token`` grows.

    ``configs`` is modified in place and also returned.
    """
    if N_token > 3840:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = False
    elif N_token > 2560:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = True
    else:
        configs.skip_amp.confidence_head = True
        configs.skip_amp.sample_diffusion = True

    return configs


def _record_error(runner: InferenceRunner, file_name: str, message: str) -> None:
    """Log ``message`` and append it to ``<runner.error_dir>/<file_name>``."""
    logger.info(message)
    with open(opjoin(runner.error_dir, file_name), "a") as f:
        f.write(message)


def infer_detect(runner: InferenceRunner, configs: Any) -> Any:
    """Run watermark detection over the dataset found in ``configs.subdir``.

    A ``BaseSingleDataset`` is built from the ``recentPDB_1536_sample384_0925``
    data config, redirected to ``configs.subdir`` (indices read from
    ``output.csv``). Each sample goes through ``runner.detect``, and its
    ``prediction["watermark"]`` logit is thresholded at ``sigmoid > 0.5``.

    Returns:
        ``None`` if the dataset/dataloader cannot be built (the error is
        appended to ``<error_dir>/error.txt``); ``False`` if there are no
        samples or the last sample is not classified positive; otherwise
        ``configs.process_success``.
    """
    try:
        data_config = configs.data
        config_dict = data_config["recentPDB_1536_sample384_0925"].to_dict()
        base_info = config_dict["base_info"]
        base_info["bioassembly_dict_dir"] = configs.subdir
        base_info["indices_fpath"] = os.path.join(configs.subdir, "output.csv")
        base_info["pdb_list"] = ""
        params = {
            "name": "detection_data",
            **base_info,
            "cropping_configs": config_dict["cropping_configs"],
            "error_dir": configs.subdir,
            "msa_featurizer": None,
            "template_featurizer": None,
            "lig_atom_rename": False,
            "shuffle_mols": False,
            "shuffle_sym_ids": False,
        }
        test_dataset = BaseSingleDataset(**params)
        dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            sampler=None,
            # batch_size=1: unwrap the single sample instead of collating.
            collate_fn=lambda batch: batch[0],
        )
    except Exception as e:
        _record_error(runner, "error.txt", f"{e}:\n{traceback.format_exc()}")
        return None

    # Default for an empty dataset; otherwise reflects the last sample only.
    predicted_class = False
    for batch in dataloader:
        prediction, label_dict = runner.detect(batch)
        pred_code = torch.clamp(prediction["watermark"], min=-10, max=10)
        predicted_class = torch.sigmoid(pred_code).item() > 0.5
        # logger.info(
        #     f"Recovery: {get_recovery(prediction['watermark'], label_dict['watermark'])}"
        # )

    return predicted_class and configs.process_success


def infer_predict(
    runner: InferenceRunner, configs: Any
) -> Optional[tuple[Optional[str], Any]]:
    """Predict and dump structures for every input sample and every seed.

    Samples come from ``get_inference_dataloader(configs)``. Per-sample
    failures (a non-empty data-pipeline error message, or any exception while
    predicting/dumping) are logged and appended to
    ``<error_dir>/<sample_name>.txt``. Processing then continues with the
    next sample.

    Returns:
        ``None`` if the dataloader cannot be built; otherwise
        ``(sample_name, seed)`` of the last batch visited, which is
        ``(None, None)`` when there are no seeds.
    """
    logger.info(f"Loading data from\n{configs.input_json_path}")
    try:
        dataloader = get_inference_dataloader(configs=configs)
    except Exception as e:
        _record_error(runner, "error.txt", f"{e}:\n{traceback.format_exc()}")
        return None

    num_data = len(dataloader.dataset)
    sample_name, seed = None, None
    for seed in configs.seeds:
        seed_everything(seed=seed, deterministic=configs.deterministic)
        for batch in dataloader:
            # Reset per batch so an early failure is never reported under the
            # previous sample's name.
            sample_name = None
            try:
                data, atom_array, data_error_message = batch[0]
                sample_name = data["sample_name"]

                if len(data_error_message) > 0:
                    _record_error(runner, f"{sample_name}.txt", data_error_message)
                    continue

                logger.info(
                    (
                        f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: "
                        f"N_asym {data['N_asym'].item()}, N_token {data['N_token'].item()}, "
                        f"N_atom {data['N_atom'].item()}, N_msa {data['N_msa'].item()}"
                    )
                )
                new_configs = update_inference_configs(configs, data["N_token"].item())
                runner.update_model_configs(new_configs)
                prediction, label_dict = runner.predict(data, configs.watermark)
                runner.dumper.dump(
                    dataset_name="",
                    pdb_id=sample_name,
                    seed=seed,
                    pred_dict=prediction,
                    atom_array=atom_array,
                    entity_poly_type=data["entity_poly_type"],
                    saved_path=configs.saved_path,
                )

                logger.info(
                    f"[Rank {DIST_WRAPPER.rank}] {sample_name} succeeded.\n"
                    f"Results saved to {configs.dump_dir}"
                )
                # logger.info(
                #     f"Recovery: {get_recovery(prediction['watermark'], label_dict['watermark'])}"
                # )
                torch.cuda.empty_cache()
            except Exception as e:
                error_message = f"[Rank {DIST_WRAPPER.rank}]{sample_name} {e}:\n{traceback.format_exc()}"
                _record_error(runner, f"{sample_name}.txt", error_message)
                if hasattr(torch.cuda, "empty_cache"):
                    torch.cuda.empty_cache()
    return sample_name, seed


def main(configs: Any) -> None:
    """Build an ``InferenceRunner`` and run prediction over the inputs."""
    runner = InferenceRunner(configs)
    infer_predict(runner, configs)


def run() -> None:
    """CLI entry point: configure logging, parse configs, fetch assets, run."""
    LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
    logging.basicConfig(
        format=LOG_FORMAT,
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
        filemode="w",
    )
    # The historical variable name is misspelled ("ATTTENTION"); honour both
    # spellings so existing launch scripts keep working.
    configs_base["use_deepspeed_evo_attention"] = any(
        os.environ.get(name, False) == "true"
        for name in ("USE_DEEPSPEED_EVO_ATTENTION", "USE_DEEPSPEED_EVO_ATTTENTION")
    )
    configs = {**configs_base, "data": data_configs, **inference_configs}
    configs = parse_configs(
        configs=configs,
        arg_str=parse_sys_args(),
        fill_required_with_null=True,
    )
    download_infercence_cache(configs, model_version="v0.2.0")
    main(configs)


if __name__ == "__main__":
    run()