import logging |
import os |
import gdown |
import traceback |
import urllib.request |
from contextlib import nullcontext |
from os.path import exists as opexists |
from os.path import join as opjoin |
from typing import Any, Mapping |
import torch |
import torch.distributed as dist |
from torch.utils.data import DataLoader |
from configs.configs_base import configs as configs_base |
from configs.configs_data import data_configs |
from configs.configs_inference import inference_configs |
from runner.dumper import DataDumper |
from protenix.config import parse_configs, parse_sys_args |
from protenix.data.infer_data_pipeline import get_inference_dataloader |
from protenix.model.protenix_edit import Protenix |
from protenix.utils.distributed import DIST_WRAPPER |
from protenix.utils.seed import seed_everything |
from protenix.utils.torch_utils import to_device |
from protenix.data.dataset import BaseSingleDataset |
from protenix.web_service.dependency_url import URL |
# Module-level logger used by the inference / detection helpers in this file.
logger = logging.getLogger(__name__)
def get_recovery(pred_code, gt_code):
    """Return the mean per-row bit-recovery rate of predicted codes.

    ``pred_code`` holds logits. They are clamped to [-10, 10], passed through
    a sigmoid and thresholded at 0.5 to get hard 0/1 predictions. These are
    compared element-wise with ``gt_code``. The match rate is averaged over
    the last dimension and floored at 1e-6. That per-row rate is then
    averaged over all remaining entries.
    """
    floor = 1e-6
    logits = pred_code.clamp(min=-10, max=10)
    hard_bits = logits.sigmoid().gt(0.5).float()
    per_row = hard_bits.eq(gt_code.float()).float().mean(dim=-1)
    return per_row.clamp(min=floor).mean()
class InferenceRunner(object):
    """Set up a Protenix model from a checkpoint and run prediction / detection.

    Construction selects the device, optionally initialises the distributed
    process group, creates the dump/error directories, builds the model,
    loads the checkpoint (switching the model to eval mode) and creates a
    ``DataDumper`` for writing predictions.
    """

    def __init__(self, configs: Any) -> None:
        self.configs = configs
        self.init_env()
        self.init_basics()
        self.init_model()
        self.load_checkpoint()
        self.init_dumper(
            need_atom_confidence=configs.need_atom_confidence,
            sorted_by_ranking_score=configs.sorted_by_ranking_score,
        )

    def init_env(self) -> None:
        """Select the device and initialise distributed / optional kernel settings.

        Uses ``cuda:<local_rank>`` when any CUDA device is present, otherwise
        the CPU. Initialises an NCCL process group when the world size is
        greater than 1.
        """
        self.print(
            f"Distributed environment: world size: {DIST_WRAPPER.world_size}, "
            + f"global rank: {DIST_WRAPPER.rank}, local rank: {DIST_WRAPPER.local_rank}"
        )
        self.use_cuda = torch.cuda.device_count() > 0
        if self.use_cuda:
            self.device = torch.device("cuda:{}".format(DIST_WRAPPER.local_rank))
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
            devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
            logging.info(
                f"LOCAL_RANK: {DIST_WRAPPER.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]"
            )
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
        if DIST_WRAPPER.world_size > 1:
            dist.init_process_group(backend="nccl")
        if self.configs.use_deepspeed_evo_attention:
            env = os.getenv("CUTLASS_PATH", None)
            self.print(f"env: {env}")
            assert (
                env is not None
            ), "if use ds4sci, set `CUTLASS_PATH` env as https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/"
            # Re-checked because the assert above is stripped under `python -O`.
            if env is not None:
                logging.info(
                    "The kernels will be compiled when DS4Sci_EvoformerAttention is called for the first time."
                )
        use_fastlayernorm = os.getenv("LAYERNORM_TYPE", None)
        if use_fastlayernorm == "fast_layernorm":
            logging.info(
                "The kernels will be compiled when fast_layernorm is called for the first time."
            )
        logging.info("Finished init ENV.")

    def init_basics(self) -> None:
        """Create ``configs.dump_dir`` and its ``ERR`` sub-directory for error logs."""
        self.dump_dir = self.configs.dump_dir
        self.error_dir = opjoin(self.dump_dir, "ERR")
        os.makedirs(self.dump_dir, exist_ok=True)
        os.makedirs(self.error_dir, exist_ok=True)

    def init_model(self) -> None:
        """Instantiate the Protenix model on the selected device."""
        self.model = Protenix(self.configs).to(self.device)

    def load_checkpoint(self) -> None:
        """Load ``configs.load_checkpoint_path`` into the model and switch to eval mode.

        A leading ``module.`` prefix on the state-dict keys is stripped when
        the first key carries it. This presumably comes from a checkpoint
        saved from a DDP/DataParallel-wrapped model.

        Raises:
            FileNotFoundError: if the checkpoint path does not exist.
        """
        checkpoint_path = self.configs.load_checkpoint_path
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Given checkpoint path not exist [{checkpoint_path}]"
            )
        self.print(
            f"Loading from {checkpoint_path}, strict: {self.configs.load_strict}"
        )
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Peek at one key without materialising the whole key list.
        sample_key = next(iter(checkpoint["model"]), "")
        self.print(f"Sampled key: {sample_key}")
        if sample_key.startswith("module."):
            # removeprefix leaves any key without the prefix intact.
            checkpoint["model"] = {
                k.removeprefix("module."): v for k, v in checkpoint["model"].items()
            }
        self.model.load_state_dict(
            state_dict=checkpoint["model"],
            strict=self.configs.load_strict,
        )
        self.model.eval()
        self.print("Finish loading checkpoint.")

    def init_dumper(
        self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True
    ):
        """Create the ``DataDumper`` that writes predictions under ``dump_dir``."""
        self.dumper = DataDumper(
            base_dir=self.dump_dir,
            need_atom_confidence=need_atom_confidence,
            sorted_by_ranking_score=sorted_by_ranking_score,
        )

    def _amp_context(self):
        """Return a CUDA autocast context for ``configs.dtype``, or a no-op context.

        Raises:
            KeyError: if ``configs.dtype`` is not one of "fp32", "bf16", "fp16".
        """
        eval_precision = {
            "fp32": torch.float32,
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
        }[self.configs.dtype]
        if torch.cuda.is_available():
            return torch.autocast(device_type="cuda", dtype=eval_precision)
        return nullcontext()

    @torch.no_grad()
    def predict(
        self, data: Mapping[str, Mapping[str, Any]], watermark=False
    ) -> tuple[Any, Any]:
        """Run an inference-mode forward pass on ``data["input_feature_dict"]``.

        Args:
            data: a sample mapping with an ``"input_feature_dict"`` entry. It is
                moved to the runner's device first.
            watermark: forwarded to the model's ``watermark`` argument.

        Returns:
            ``(prediction, label_dict)`` as returned by the model.
        """
        enable_amp = self._amp_context()
        data = to_device(data, self.device)
        with enable_amp:
            prediction, label_dict, _ = self.model(
                input_feature_dict=data["input_feature_dict"],
                label_full_dict=None,
                label_dict=None,
                mode="inference",
                watermark=watermark,
            )
        return prediction, label_dict

    @torch.no_grad()
    def detect(self, data: Mapping[str, Mapping[str, Any]]) -> tuple[Any, Any]:
        """Run a detection forward pass (``detect=True``) with the sample's labels.

        Args:
            data: a sample mapping with ``"input_feature_dict"``, ``"label_dict"``
                and ``"label_full_dict"`` entries. It is moved to the runner's
                device first.

        Returns:
            ``(prediction, label_dict)`` as returned by the model.
        """
        enable_amp = self._amp_context()
        data = to_device(data, self.device)
        with enable_amp:
            prediction, label_dict, _ = self.model(
                input_feature_dict=data["input_feature_dict"],
                label_dict=data["label_dict"],
                label_full_dict=data["label_full_dict"],
                mode="inference",
                detect=True,
            )
        return prediction, label_dict

    def print(self, msg: str):
        """Log ``msg`` at INFO level on global rank 0 only."""
        if DIST_WRAPPER.rank == 0:
            logger.info(msg)

    def update_model_configs(self, new_configs: Any) -> None:
        """Replace the configs object held by the model."""
        self.model.configs = new_configs
def download_infercence_cache(configs: Any = None, model_version: str = "v0.2.0") -> None:
    """Download the CCD cache files and the model checkpoint if they are missing.

    Files are fetched from Google Drive with ``gdown`` into paths relative to
    the current working directory. Files that already exist are left
    untouched:

    * ``./release_data/ccd_cache/components.v20240608.cif``
    * ``./release_data/ccd_cache/components.v20240608.cif.rdkit_mol.pkl``
    * ``./checkpoint.pt``

    Args:
        configs: accepted for call-site compatibility (callers may pass the
            parsed configs); currently unused.
        model_version: accepted for call-site compatibility; currently unused.
    """
    code_directory = "./"
    data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache")
    os.makedirs(data_cache_dir, exist_ok=True)
    # (Google Drive file id, local destination)
    downloads = (
        (
            "1e8wxpuEB-0xL_3dlMfZCFo6cL5oSHSUK",
            os.path.join(data_cache_dir, "components.v20240608.cif"),
        ),
        (
            "1R9d678aBfQwTd0Rh15doRmW-fETNdeWf",
            os.path.join(data_cache_dir, "components.v20240608.cif.rdkit_mol.pkl"),
        ),
        (
            "17zBIRed3xZM8ux0bq2hpf1oFC75Y7OEw",
            os.path.join(code_directory, "checkpoint.pt"),
        ),
    )
    for file_id, output_file in downloads:
        if os.path.exists(output_file):
            continue
        download_url = f"https://drive.google.com/uc?id={file_id}"
        gdown.download(download_url, output_file, quiet=False)
def update_inference_configs(configs: Any, N_token: int):
    """Choose which model stages skip AMP based on the token count.

    * ``N_token <= 2560``: both the confidence head and sample diffusion skip AMP.
    * ``2560 < N_token <= 3840``: only sample diffusion skips AMP.
    * ``N_token > 3840``: neither skips AMP.

    ``configs`` is modified in place and also returned.
    """
    skip_confidence = N_token <= 2560
    skip_diffusion = N_token <= 3840
    configs.skip_amp.confidence_head = skip_confidence
    configs.skip_amp.sample_diffusion = skip_diffusion
    return configs
def _build_detection_dataloader(configs: Any) -> DataLoader:
    """Build a batch-size-1 DataLoader over the structures listed in ``configs.subdir``."""
    config_dict = configs.data["recentPDB_1536_sample384_0925"].to_dict()
    base_info = config_dict["base_info"]
    # Re-point the dataset at the directory holding the structures to check.
    base_info["bioassembly_dict_dir"] = configs.subdir
    base_info["indices_fpath"] = os.path.join(configs.subdir, "output.csv")
    base_info["pdb_list"] = ""
    params = {
        "name": "detection_data",
        **base_info,
        "cropping_configs": config_dict["cropping_configs"],
        "error_dir": configs.subdir,
        "msa_featurizer": None,
        "template_featurizer": None,
        "lig_atom_rename": False,
        "shuffle_mols": False,
        "shuffle_sym_ids": False,
    }
    test_dataset = BaseSingleDataset(**params)
    return DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        sampler=None,
        # Yield the single dataset item itself rather than a collated list.
        collate_fn=lambda batch: batch[0],
    )


def infer_detect(runner: InferenceRunner, configs: Any) -> Any:
    """Run watermark detection on the structures in ``configs.subdir``.

    The dataset is built from the ``recentPDB_1536_sample384_0925`` data
    config, re-pointed at ``configs.subdir`` (index file ``output.csv``).
    ``runner.detect`` is run on every batch.

    Args:
        runner: an initialised ``InferenceRunner``.
        configs: parsed configs; reads ``data``, ``subdir`` and
            ``process_success``.

    Returns:
        ``predicted_class and configs.process_success`` for the LAST batch.
        ``predicted_class`` is True when the sigmoid of the watermark logit
        (clamped to [-10, 10]) exceeds 0.5. Returns ``None`` in two cases:
        building the dataloader fails (the error is appended to
        ``<error_dir>/error.txt``), or the dataloader yields no batches.
    """
    try:
        dataloader = _build_detection_dataloader(configs)
    except Exception as e:
        error_message = f"{e}:\n{traceback.format_exc()}"
        logger.info(error_message)
        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
            f.write(error_message)
        return None

    predicted_class = None
    for batch in dataloader:
        prediction, _ = runner.detect(batch)
        pred_code = torch.clamp(prediction["watermark"], min=-10, max=10)
        # NOTE(review): only the last batch's result is kept; confirm the
        # detection dataset is expected to contain a single entry.
        predicted_class = torch.sigmoid(pred_code).item() > 0.5

    if predicted_class is None:
        # Previously this fell through to a NameError on the return below.
        logger.info(f"No samples found for watermark detection in {configs.subdir}")
        return None
    return predicted_class and configs.process_success
def infer_predict(runner: InferenceRunner, configs: Any) -> Any:
    """Predict structures for every sample in ``configs.input_json_path``.

    For each seed in ``configs.seeds``, the RNGs are re-seeded. Every sample
    from the inference dataloader is then predicted and written out through
    ``runner.dumper``. Per-sample failures are logged, appended to
    ``<error_dir>/<sample_name>.txt`` and skipped. These are either
    featurisation errors reported by the dataloader or exceptions raised
    during prediction or dumping.

    Args:
        runner: an initialised ``InferenceRunner``.
        configs: parsed configs; reads ``input_json_path``, ``seeds``,
            ``deterministic``, ``watermark``, ``saved_path`` and ``dump_dir``.
            It is also passed to ``get_inference_dataloader`` and
            ``update_inference_configs``.

    Returns:
        ``(sample_name, seed)`` for the last processed batch and seed.
        ``sample_name`` is ``None`` if no name could be read for that batch
        or nothing was processed, and ``seed`` is ``None`` if there were no
        seeds. Returns ``None`` if the dataloader could not be built (the
        error is appended to ``<error_dir>/error.txt``).
    """
    logger.info(f"Loading data from\n{configs.input_json_path}")
    try:
        dataloader = get_inference_dataloader(configs=configs)
    except Exception as e:
        error_message = f"{e}:\n{traceback.format_exc()}"
        logger.info(error_message)
        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
            f.write(error_message)
        return None

    num_data = len(dataloader.dataset)
    # Initialised so the final return is defined even with no seeds/samples.
    sample_name, seed = None, None
    for seed in configs.seeds:
        seed_everything(seed=seed, deterministic=configs.deterministic)
        for batch in dataloader:
            # Reset per batch so an early failure is never attributed to the
            # previous sample (or hits an unbound name on the first batch).
            sample_name = None
            try:
                data, atom_array, data_error_message = batch[0]
                sample_name = data["sample_name"]
                if len(data_error_message) > 0:
                    logger.info(data_error_message)
                    with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                        f.write(data_error_message)
                    continue
                logger.info(
                    (
                        f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: "
                        f"N_asym {data['N_asym'].item()}, N_token {data['N_token'].item()}, "
                        f"N_atom {data['N_atom'].item()}, N_msa {data['N_msa'].item()}"
                    )
                )
                new_configs = update_inference_configs(configs, data["N_token"].item())
                runner.update_model_configs(new_configs)
                prediction, label_dict = runner.predict(data, configs.watermark)
                runner.dumper.dump(
                    dataset_name="",
                    pdb_id=sample_name,
                    seed=seed,
                    pred_dict=prediction,
                    atom_array=atom_array,
                    entity_poly_type=data["entity_poly_type"],
                    saved_path=configs.saved_path,
                )
                logger.info(
                    f"[Rank {DIST_WRAPPER.rank}] {sample_name} succeeded.\n"
                    f"Results saved to {configs.dump_dir}"
                )
                torch.cuda.empty_cache()
            except Exception as e:
                err_name = sample_name if sample_name is not None else "unknown_sample"
                error_message = f"[Rank {DIST_WRAPPER.rank}]{err_name} {e}:\n{traceback.format_exc()}"
                logger.info(error_message)
                with open(opjoin(runner.error_dir, f"{err_name}.txt"), "a") as f:
                    f.write(error_message)
                if hasattr(torch.cuda, "empty_cache"):
                    torch.cuda.empty_cache()
    return sample_name, seed
def main(configs: Any) -> None:
    """Build an ``InferenceRunner`` from ``configs`` and run prediction with it."""
    infer_predict(InferenceRunner(configs), configs)
def run() -> None:
    """Script entry point: configure logging, parse configs, fetch assets, run inference.

    DeepSpeed Evoformer attention is enabled when the environment variable
    ``USE_DEEPSPEED_EVO_ATTTENTION`` (historical, misspelled name) or
    ``USE_DEEPSPEED_EVO_ATTENTION`` is set to ``"true"``.
    """
    LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
    # No `filemode` here: basicConfig only uses it together with `filename`.
    logging.basicConfig(
        format=LOG_FORMAT,
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # Keep honouring the original misspelled variable for existing scripts.
    configs_base["use_deepspeed_evo_attention"] = any(
        os.environ.get(name) == "true"
        for name in ("USE_DEEPSPEED_EVO_ATTTENTION", "USE_DEEPSPEED_EVO_ATTENTION")
    )
    configs = {**configs_base, **{"data": data_configs}, **inference_configs}
    configs = parse_configs(
        configs=configs,
        arg_str=parse_sys_args(),
        fill_required_with_null=True,
    )
    # The downloader takes no required arguments; passing configs/model_version
    # to the zero-argument definition raised TypeError.
    download_infercence_cache()
    main(configs)
# Script entry point: parse configs from the command line (see run()) and run inference.
if __name__ == "__main__":
    run()