import os
import time
import copy
import shutil
from pathlib import Path
from math import ceil
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from collections.abc import Iterable

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.cuda.amp import autocast, GradScaler

import pytorch_warmup as warmup

import esm

from einops import rearrange

from packaging import version

from ema_pytorch import EMA

from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem

# # --
# from PD_SpLMxDiff.ModelPack import resize_image_to,ProteinDesigner_B
# from PD_SpLMxDiff.UtilityPack import get_Model_A_error, convert_into_tokens
# from PD_SpLMxDiff.UtilityPack import decode_one_ems_token_rec,decode_many_ems_token_rec
# from PD_SpLMxDiff.UtilityPack import decode_one_ems_token_rec_for_folding,decode_many_ems_token_rec_for_folding
# from PD_SpLMxDiff.UtilityPack import decode_one_ems_token_rec_for_folding_with_mask,decode_many_ems_token_rec_for_folding_with_mask,read_mask_from_input
# from PD_SpLMxDiff.UtilityPack import get_DSSP_result, string_diff
# from PD_SpLMxDiff.DataSetPack import pad_a_np_arr
# ++
from PD_pLMProbXDiff.ModelPack import (
    resize_image_to,
    ProteinDesigner_B,
)
from PD_pLMProbXDiff.UtilityPack import (
    get_Model_A_error,
    convert_into_tokens, convert_into_tokens_using_prob,
    decode_one_ems_token_rec, decode_many_ems_token_rec,
    decode_one_ems_token_rec_for_folding, decode_many_ems_token_rec_for_folding,
    decode_one_ems_token_rec_for_folding_with_mask,
    decode_many_ems_token_rec_for_folding_with_mask,
    read_mask_from_input,
    get_DSSP_result,
    string_diff,
    load_in_pLM,
)
from PD_pLMProbXDiff.DataSetPack import (
    pad_a_np_arr
)

__version__ = '1.9.3'

device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu"
)


def cycle(dl):
    """Yield items from `dl` forever, restarting the iteration each time it is exhausted."""
    while True:
        for data in dl:
            yield data


# loss function
criterion_MSE_sum = nn.MSELoss(reduction='sum')
criterion_MAE_sum = nn.L1Loss(reduction='sum')


# helper functions

def exists(val):
    """Return True when `val` is not None."""
    return val is not None


def default(val, d):
    """Return `val` unless it is None; otherwise return `d` (called first if it is callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d


def cast_tuple(val, length = 1):
    """Return `val` as a tuple: lists are converted, other non-tuples are repeated `length` times."""
    if isinstance(val, list):
        val = tuple(val)

    return val if isinstance(val, tuple) else ((val,) * length)


# NOTE: this index-returning variant is shadowed by the second `find_first`
# defined further down in this module (which returns the element instead),
# so after import the name `find_first` refers to the later definition.
def find_first(fn, arr):
    """Return the index of the first element of `arr` for which `fn` is truthy, or -1."""
    for ind, el in enumerate(arr):
        if fn(el):
            return ind

    return -1


def pick_and_pop(keys, d):
    """Pop every key in `keys` from `d` (mutating it) and return the popped pairs as a dict."""
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))


def group_dict_by_key(cond, d):
    """Split `d` into two dicts: (entries whose key satisfies `cond`, all other entries)."""
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, str):
    """Return True when `str` starts with `prefix`."""
    return str.startswith(prefix)


def group_by_key_prefix(prefix, d):
    """Split `d` into (entries whose key starts with `prefix`, the rest)."""
    return group_dict_by_key(partial(string_begins_with, prefix), d)


def groupby_prefix_and_trim(prefix, d):
    """Like `group_by_key_prefix`, but the prefix is stripped from the keys of the first dict."""
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs


def num_to_groups(num, divisor):
    """Break `num` into a list of chunks of size `divisor`, plus a final smaller remainder chunk if any."""
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


# url to fs, bucket, path - for checkpointing to cloud

def url_to_bucket(url):
    """Return the bucket name of a `gs://` or `s3://` url, or `url` unchanged when it has no scheme.

    Raises:
        ValueError: if the url has a scheme other than `gs` or `s3`.
    """
    if '://' not in url:
        return url

    # the scheme must be bound to `prefix`: it is inspected (and reported) below
    prefix, suffix = url.split('://', 1)

    if prefix in {'gs', 's3'}:
        return suffix.split('/')[0]
    else:
        raise ValueError(f'storage type prefix "{prefix}" is not supported yet')


# decorators

def eval_decorator(fn):
    """Run `fn(model, ...)` with `model` in eval mode, then restore its previous training flag."""
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore the previous mode even when fn raises
            model.train(was_training)
    return inner


def cast_torch_tensor(fn, cast_fp16 = False):
    """Decorator for methods: convert numpy arguments to tensors and move tensors to the model device.

    The wrapped call accepts two extra keyword arguments which are consumed here:
    `_device` (defaults to `model.device`) and `_cast_device` (defaults to True).
    When `cast_fp16` is set and `model.cast_half_at_training` is truthy, every
    non-bool tensor argument is additionally cast with `.half()`.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', model.device)
        cast_device = kwargs.pop('_cast_device', True)

        should_cast_fp16 = cast_fp16 and model.cast_half_at_training

        # flatten positional and keyword values so they can be mapped uniformly,
        # remembering where the keyword values start
        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        if should_cast_fp16:
            all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner


# gradient accumulation functions

def split_iterable(it, split_size):
    """Slice the sequence `it` into consecutive pieces of at most `split_size` items."""
    accum = []
    for ind in range(ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum


def split(t, split_size = None):
    """Split a tensor (along dim 0) or an iterable into chunks of `split_size`.

    Returns `t` untouched when `split_size` is None.

    Raises:
        TypeError: if `t` is neither a tensor nor an iterable.
    """
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # previously the TypeError class itself was returned instead of being raised
    raise TypeError(f'cannot split value of type {type(t).__name__}')


def find_first(cond, arr):
    """Return the first element of `arr` for which `cond` is truthy, or None."""
    for el in arr:
        if cond(el):
            return el
    return None


def split_args_and_kwargs(*args, split_size = None, **kwargs):
    """Yield `(fraction, (chunked_args, chunked_kwargs))` for batch chunks of at most `split_size`.

    The batch size is the length of the first tensor found among the arguments.
    Tensors and iterables are split; every other value (including None) is
    repeated for each chunk. `fraction` is the chunk size divided by the batch size.
    """
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    # chunk sizes are read from the first argument's chunks
    chunk_sizes = tuple(map(len, split_all_args[0]))

    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)


# imagen trainer

def imagen_sample_in_chunks(fn):
    """Decorator: when `max_batch_size` is given, run `fn` per chunk and concatenate the outputs on dim 0."""
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        if self.imagen.unconditional:
            # nothing to split but the requested number of samples
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim = 0)

        # each call returned several tensors: concatenate them position by position
        return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))

    return inner


def restore_parts(state_dict_target, state_dict_from):
    """Copy into `state_dict_target` every entry of `state_dict_from` with a matching name and size.

    Entries missing from the target are ignored; size mismatches are printed and skipped.
    The target tensors are modified in place and the target dict is returned.
    """
    for name, param in state_dict_from.items():

        if name not in state_dict_target:
            continue

        if param.size() == state_dict_target[name].size():
            state_dict_target[name].copy_(param)
        else:
            print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}")

    return state_dict_target


class ImagenTrainer(nn.Module):
    """Training wrapper around `model` and its `model.imagen`.

    For every unet in `model.imagen.unets` the trainer creates an Adam optimizer,
    a GradScaler, a learning-rate scheduler, an optional warmup scheduler and
    (when `use_ema` is set on the main process) an EMA copy. Device placement,
    mixed precision and the backward pass go through an `Accelerator`, and
    checkpoints are written through an fsspec filesystem.

    Keyword arguments prefixed with `ema_` are forwarded to `EMA`, those prefixed
    with `accelerate_` to `Accelerator`, and the remaining ones to `Adam`.
    """

    locked = False

    def __init__(
        self,
        #imagen = None,
        model = None,
        imagen_checkpoint_path = None,
        use_ema = True,
        lr = 1e-4,
        eps = 1e-8,
        beta1 = 0.9,
        beta2 = 0.99,
        max_grad_norm = None,
        group_wd_params = True,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        only_train_unet_number = None,
        fp16 = False,
        precision = None,
        split_batches = True,
        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
        verbose = True,
        split_valid_fraction = 0.025,
        split_valid_from_train = False,
        split_random_seed = 42,
        checkpoint_path = None,
        checkpoint_every = None,
        checkpoint_fs = None,
        fs_kwargs: dict = None,
        max_checkpoints_keep = 20,
        # +++++++++++++++++++++
        CKeys=None,
        #
        **kwargs
    ):
        super().__init__()
        assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
        assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'

        # determine filesystem, using fsspec, for saving to local filesystem or cloud

        self.fs = checkpoint_fs

        if not exists(self.fs):
            fs_kwargs = default(fs_kwargs, {})
            self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)

        # # -----------------------------------
        # # from MJB
        # assert isinstance(model.imagen, (ProteinDesigner_B))
        # modified by BN
        # ++: try this trainer for all models
        # assert isinstance(model, (ProteinDesigner_B))
        # +++++++++++++++++++++++++
        self.CKeys = CKeys

        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.imagen = model.imagen
        self.model = model
        self.is_elucidated = self.model.is_elucidated

        # create accelerator instance

        accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)

        assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
        accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')

        self.accelerator = Accelerator(**{
            'split_batches': split_batches,
            'mixed_precision': accelerator_mixed_precision,
            'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)],
            **accelerate_kwargs})

        ImagenTrainer.locked = self.is_distributed

        # cast data to fp16 at training time if needed

        self.cast_half_at_training = accelerator_mixed_precision == 'fp16'

        # grad scaler must be managed outside of accelerator

        grad_scaler_enabled = fp16

        self.num_unets = len(self.imagen.unets)

        self.use_ema = use_ema and self.is_main
        self.ema_unets = nn.ModuleList([])

        # keep track of what unet is being trained on
        # only going to allow 1 unet training at a time

        self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on

        # data related functions

        self.train_dl_iter = None
        self.train_dl = None

        self.valid_dl_iter = None
        self.valid_dl = None

        self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names

        # auto splitting validation from training, if dataset is passed in

        self.split_valid_from_train = split_valid_from_train

        assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
        self.split_valid_fraction = split_valid_fraction
        self.split_random_seed = split_random_seed

        # be able to finely customize learning rate, weight decay
        # per unet

        lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))

        for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
            optimizer = Adam(
                unet.parameters(),
                lr = unet_lr,
                eps = unet_eps,
                betas = (beta1, beta2),
                **kwargs
            )

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = grad_scaler_enabled)

            scheduler = warmup_scheduler = None

            if exists(unet_cosine_decay_max_steps):
                scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)

            if exists(unet_warmup_steps):
                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps)

                if not exists(scheduler):
                    # constant schedule, so the warmup has a scheduler to dampen
                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

            # set on object

            setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
            setattr(self, f'scaler{ind}', scaler)
            setattr(self, f'scheduler{ind}', scheduler)
            setattr(self, f'warmup{ind}', warmup_scheduler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        # step tracker and misc

        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        self.verbose = verbose

        # automatic set devices based on what accelerator decided

        self.imagen.to(self.device)
        self.to(self.device)

        # checkpointing

        assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.max_checkpoints_keep = max_checkpoints_keep

        self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main

        if exists(checkpoint_path) and self.can_checkpoint:
            bucket = url_to_bucket(checkpoint_path)

            if not self.fs.exists(bucket):
                self.fs.mkdir(bucket)

            self.load_from_checkpoint_folder()

        # only allowing training for unet

        self.only_train_unet_number = only_train_unet_number
        self.validate_and_set_unet_being_trained(only_train_unet_number)

    # computed values

    @property
    def device(self):
        """Device chosen by the accelerator."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """True unless the accelerator reports no distribution and a single process."""
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        """The unet currently being trained, with the accelerator wrapping removed."""
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # optimizer helper functions

    def get_lr(self, unet_number):
        """Return the learning rate of the first param group of the optimizer of unet `unet_number` (1-based)."""
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # function for allowing only one unet from being trained at a time

    def validate_and_set_unet_being_trained(self, unet_number = None):
        """Record `unet_number` as the only unet this trainer trains and wrap it with the accelerator.

        Asserts that no different unet number was set before. With `unet_number = None`
        nothing is wrapped.
        """
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        """Pass the unet, its optimizer and its scheduler through `accelerator.prepare` (done once only)."""
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        self.unet_being_trained = self.accelerator.prepare(unet)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        optimizer = self.accelerator.prepare(optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # hacking accelerator due to not having separate gradscaler per optimizer

    def set_accelerator_scaler(self, unet_number):
        """Install the GradScaler of unet `unet_number` on the accelerator and on its optimizers."""
        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler

    # helper print

    def print(self, msg):
        """Print `msg` through the accelerator, only on the main process and only when verbose."""
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # validating the unet number

    def validate_unet_number(self, unet_number = None):
        """Return a checked 1-based unet number; it defaults to 1 when there is a single unet."""
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
        return unet_number

    # number of training steps taken

    def num_steps_taken(self, unet_number = None):
        """Return the number of update steps recorded for unet `unet_number` (1-based)."""
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        return self.steps[unet_number - 1].item()

    def print_untrained_unets(self):
        """Print a notice for every unet that has zero recorded steps and is not a NullUnet."""
        print_final_error = False

        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            # NOTE(review): `NullUnet` is not imported in the visible part of this
            # module - confirm it is defined or imported elsewhere in the file,
            # otherwise this raises NameError for any unet with zero steps.
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        if print_final_error:
            self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # data related functions

    def add_train_dataloader(self, dl = None):
        """Register `dl` (prepared by the accelerator) as the training dataloader; no-op for None."""
        if not exists(dl):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'
        self.train_dl = self.accelerator.prepare(dl)

    def add_valid_dataloader(self, dl):
        """Register `dl` (prepared by the accelerator) as the validation dataloader; no-op for None."""
        if not exists(dl):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'
        self.valid_dl = self.accelerator.prepare(dl)

    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
        """Build and register the training dataloader from `ds`.

        When `split_valid_from_train` is set, a seeded random split of
        `split_valid_fraction` of the samples is registered as validation data.
        """
        if not exists(ds):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'

        valid_ds = None
        if self.split_valid_from_train:
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples')

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.train_dl = self.accelerator.prepare(dl)

        if not self.split_valid_from_train:
            return

        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)

    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        """Build and register the validation dataloader from `ds`; no-op for None."""
        if not exists(ds):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.valid_dl = self.accelerator.prepare(dl)

    def create_train_iter(self):
        """Create the endless training iterator once."""
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    def create_valid_iter(self):
        """Create the endless validation iterator once."""
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

    def train_step(self, unet_number = None, **kwargs):
        """Run forward/backward on the next training batch, apply the update, and return the loss."""
        self.create_train_iter()
        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
        self.update(unet_number = unet_number)
        return loss

    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        """Compute the loss on the next validation batch, optionally with the EMA unets (`use_ema_unets = True`)."""
        self.create_valid_iter()

        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext

        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss

    def step_with_dl_iter(self, dl_iter, **kwargs):
        """Fetch one batch, name its items by `dl_tuple_output_keywords_names`, and call `forward`."""
        dl_tuple_output = cast_tuple(next(dl_iter))
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        # batch items take precedence over same-named keyword arguments
        loss = self.forward(**{**kwargs, **model_input})
        return loss

    # checkpointing functions

    @property
    def all_checkpoints_sorted(self):
        """Checkpoint paths matching `*.pt` in `checkpoint_path`, highest step count first."""
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        # file names look like `checkpoint.<total_steps>.pt`
        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
        return sorted_checkpoints

    def load_from_checkpoint_folder(self, last_total_steps = -1):
        """Load the checkpoint for `last_total_steps`, or the highest-step one when it is -1."""
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    def save_to_checkpoint_folder(self):
        """Save `checkpoint.<total_steps>.pt` and delete all but the newest `max_checkpoints_keep` files."""
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)

    # saving and loading functions

    def save(
        self,
        path,
        overwrite = True,
        without_optim_and_sched = False,
        **kwargs
    ):
        """Write model weights, step counts and (optionally) optimizer/scheduler/scaler/EMA states to `path`.

        Extra keyword arguments are stored in the checkpoint as additional entries.
        Does nothing on processes that are not allowed to checkpoint.
        """
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite)

        self.reset_ema_unets_all_one_device()

        save_obj = dict(
            model = self.imagen.state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine if imagen config is available

        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""')

            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        #save to path

        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')

    def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
        """Restore trainer state from the checkpoint at `path` and return the loaded object.

        If loading the weights raises RuntimeError, a partial load through
        `restore_parts` is attempted. With `only_model = True` only the model
        weights are restored. With `noop_if_not_exist = True` a missing file is
        reported and None is returned.
        """
        fs = self.fs

        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        assert fs.exists(path), f'{path} does not exist'

        self.reset_ema_unets_all_one_device()

        # to avoid extra GPU memory usage in main process when using Accelerate

        # SECURITY NOTE: torch.load unpickles the file - only load trusted checkpoints
        with fs.open(path) as f:
            loaded_obj = torch.load(f, map_location='cpu')

        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        if only_model:
            return loaded_obj

        self.steps.copy_(loaded_obj['steps'])

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            if exists(optimizer):
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except Exception:
                    # deliberate best effort (also covers checkpoints saved without
                    # optimizer state); narrowed from a bare `except` so that
                    # KeyboardInterrupt / SystemExit are no longer swallowed
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

    # managing ema unets and their devices

    @property
    def unets(self):
        """A fresh ModuleList holding the `ema_model` of every EMA wrapper."""
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def get_ema_unet(self, unet_number = None):
        """Return the EMA wrapper of unet `unet_number`, moving it to the trainer device and the others to cpu.

        Returns None when EMA is disabled.
        """
        if not self.use_ema:
            return

        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        if isinstance(self.unets, nn.ModuleList):
            # swap the registered ModuleList for a plain python list
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    def reset_ema_unets_all_one_device(self, device = None):
        """Re-register all EMA wrappers as a ModuleList and move them to `device` (default: trainer device)."""
        if not self.use_ema:
            return

        device = default(device, self.device)
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        self.ema_unet_being_trained_index = -1

    @torch.no_grad()
    @contextmanager
    def use_ema_unets(self):
        """Context manager that swaps the EMA unets into `self.imagen` for the duration of the block.

        When EMA is disabled the block simply runs with the regular unets.
        """
        if not self.use_ema:
            output = yield
            return output

        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        self.unets.eval()

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets                  # swap in exponential moving averaged unets for sampling

        try:
            output = yield
        finally:
            # restore even when the with-body raises, so the trainer is not
            # left holding the EMA unets in place of the trainable ones
            self.imagen.unets = trainable_unets         # restore original training unets

            # cast the ema_model unets back to original device
            for ema in self.ema_unets:
                ema.restore_ema_model_device()

        return output

    def print_unet_devices(self):
        """Print the device of the first parameter of every unet and, with EMA enabled, of every EMA unet."""
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # overriding state dict functions

    def state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # encoding text functions

    def encode_text(self, text, **kwargs):
        """Delegate to `self.imagen.encode_text`."""
        return self.imagen.encode_text(text, **kwargs)

    # forwarding functions and gradient step updates

    def update(self, unet_number = None):
        """Apply one optimizer step for unet `unet_number` using the accumulated gradients.

        Also clips gradients (when `max_grad_norm` is set), updates the EMA copy,
        steps the scheduler, increments the step counter and saves a checkpoint
        every `checkpoint_every` total steps.
        """
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        # set the grad scaler on the accelerator, since we are managing one per u-net
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # scheduler, if needed

        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs
                scheduler.step()

        self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()

    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        """Sample from `self.imagen`, using the EMA unets unless `use_non_ema = True` is passed."""
        context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()

        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device = self.device, **kwargs)

        return output

    def _debug_enabled(self):
        """Return True when `CKeys['Debug_TrainerPack'] == 1`; False when CKeys is None or lacks the key."""
        if not exists(self.CKeys):
            return False
        try:
            return self.CKeys['Debug_TrainerPack'] == 1
        except KeyError:
            return False

    @partial(cast_torch_tensor, cast_fp16 = True)
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
    ):
        """Compute the loss of `self.model` on a batch, in chunks of at most `max_batch_size`.

        Each chunk loss is weighted by its share of the batch. In training mode
        every chunk loss is back-propagated through the accelerator, so gradients
        accumulate across chunks. Returns the sum of the weighted chunk losses.
        """
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        total_loss = 0.

        # the default CKeys = None used to crash the subscript here
        debug = self._debug_enabled()

        # + for debug
        if debug:
            print("In Trainer:Forward, check inputs:")
            print('args: ', len(args))
            if len(args) > 0:
                print('args in:', args[0].shape)
            print('kwargs: ', kwargs.keys())

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # + for debug
            if debug:
                print("after chunks,...")
                print('chun_frac: ', chunk_size_frac)
                print('chun_args: ', chunked_args)
                print('chun_kwargs: ', chunked_kwargs)

            with self.accelerator.autocast():
                loss = self.model(
                    *chunked_args,
                    unet_number = unet_number,
                    **chunked_kwargs
                )
                loss = loss * chunk_size_frac

            # + for debug
            if debug:
                print('part chun loss: ', loss)

            # the loss is accumulated as returned by the model (no `.item()`)
            total_loss += loss#.item()

            if self.training:
                self.accelerator.backward(loss)

        return total_loss

# ======================================================== # class ImagenTrainer_ModelB(nn.Module): locked = False def __init__( self, #imagen = None, model = None, imagen_checkpoint_path = None, use_ema = True, lr = 1e-4, eps = 1e-8, beta1 = 0.9, beta2 = 0.99, max_grad_norm = None, group_wd_params = True, warmup_steps = None, cosine_decay_max_steps = None, only_train_unet_number = None, fp16 = False, precision = None, split_batches = True, dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), verbose = True, split_valid_fraction = 0.025, split_valid_from_train = False, split_random_seed = 42, checkpoint_path = None, checkpoint_every =
None, checkpoint_fs = None, fs_kwargs: dict = None, max_checkpoints_keep = 20, # +++++++++++++++++++++ CKeys=None, # **kwargs ): super().__init__() assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' # determine filesystem, using fsspec, for saving to local filesystem or cloud self.fs = checkpoint_fs if not exists(self.fs): fs_kwargs = default(fs_kwargs, {}) self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) # # ----------------------------------- # # from MJB # assert isinstance(model.imagen, (ProteinDesigner_B)) # modified by BN # ++ assert isinstance(model, (ProteinDesigner_B)) # +++++++++++++++++++++++++ self.CKeys = CKeys ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) self.imagen = model.imagen self.model=model self.is_elucidated = self.model.is_elucidated # create accelerator instance accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') self.accelerator = Accelerator(**{ 'split_batches': split_batches, 'mixed_precision': accelerator_mixed_precision, 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] , **accelerate_kwargs}) ImagenTrainer.locked = self.is_distributed # cast data to fp16 at training time if needed self.cast_half_at_training = accelerator_mixed_precision == 'fp16' # grad scaler must be managed outside of accelerator grad_scaler_enabled = fp16 self.num_unets = len(self.imagen.unets) self.use_ema = use_ema and 
self.is_main
        self.ema_unets = nn.ModuleList([])

        # keep track of what unet is being trained on
        # only going to allow 1 unet training at a time

        self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on

        # data related functions

        self.train_dl_iter = None
        self.train_dl = None

        self.valid_dl_iter = None
        self.valid_dl = None

        self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names

        # auto splitting validation from training, if dataset is passed in

        self.split_valid_from_train = split_valid_from_train

        assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
        self.split_valid_fraction = split_valid_fraction
        self.split_random_seed = split_random_seed

        # be able to finely customize learning rate, weight decay
        # per unet

        lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))

        for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
            optimizer = Adam(
                unet.parameters(),
                lr = unet_lr,
                eps = unet_eps,
                betas = (beta1, beta2),
                **kwargs
            )

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = grad_scaler_enabled)

            scheduler = warmup_scheduler = None

            if exists(unet_cosine_decay_max_steps):
                scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)

            if exists(unet_warmup_steps):
                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps)

                if not exists(scheduler):
                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

            # set on object

            setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
            setattr(self, f'scaler{ind}', scaler)
            setattr(self, f'scheduler{ind}', scheduler)
            setattr(self, f'warmup{ind}', warmup_scheduler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        # step tracker and misc

        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        self.verbose = verbose

        # automatic set devices based on what accelerator decided

        self.imagen.to(self.device)
        self.to(self.device)

        # checkpointing

        assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.max_checkpoints_keep = max_checkpoints_keep

        self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main

        if exists(checkpoint_path) and self.can_checkpoint:
            bucket = url_to_bucket(checkpoint_path)

            if not self.fs.exists(bucket):
                self.fs.mkdir(bucket)

            self.load_from_checkpoint_folder()

        # only allowing training for unet

        self.only_train_unet_number = only_train_unet_number
        self.validate_and_set_unet_being_trained(only_train_unet_number)

    # computed values

    @property
    def device(self):
        """Device reported by the accelerator (`self.accelerator.device`)."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """True unless the accelerator reports `DistributedType.NO` with exactly one process."""
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        """Value of `accelerator.is_main_process`."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """Value of `accelerator.is_local_main_process`."""
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        """The unet currently being trained, passed through `accelerator.unwrap_model`."""
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # optimizer helper functions

    def get_lr(self, unet_number):
        """Return the `lr` of the first param group of the optimizer for `unet_number` (1-based)."""
        # NOTE(review): the value returned by validate_unet_number is discarded, so the
        # single-unet default (None -> 1) is not applied to the subtraction below.
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # function for allowing only one unet from being trained at a time

    def validate_and_set_unet_being_trained(self, unet_number = None):
        """Record `unet_number` as the only unet this trainer trains and wrap it.

        Asserts that a previously recorded `only_train_unet_number` equals
        `unet_number`. The number is stored on both the trainer and
        `self.imagen`; when it is None nothing is wrapped.
        """
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        """Pass the selected unet, its optimizer and its scheduler through `accelerator.prepare`.

        Only the first call does any work: later calls return immediately because
        the `one_unet_wrapped` attribute is set at the end of the first one.
        """
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        self.unet_being_trained = self.accelerator.prepare(unet)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        optimizer = self.accelerator.prepare(optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        # store the prepared objects back under the same attribute names
        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # hacking accelerator due to not having separate gradscaler per optimizer

    def set_accelerator_scaler(self, unet_number):
        """Install the GradScaler of `unet_number` on the accelerator and on every optimizer in `accelerator._optimizers`."""
        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler

    # helper print

    def print(self, msg):
        """Print `msg` via `accelerator.print`, only when `is_main` and `verbose` are both true."""
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # validating the unet number

    def validate_unet_number(self, unet_number = None):
        """Default `unet_number` to 1 when there is a single unet, assert 1 <= unet_number <= num_unets, and return it."""
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
        return unet_number

    # number of training steps taken

    def num_steps_taken(self, unet_number = None):
        """Return the entry of the `steps` buffer for `unet_number` (1-based; defaults to 1 when there is a single unet)."""
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        return self.steps[unet_number - 1].item()

    def print_untrained_unets(self):
        """Print a notice for every unet whose step count is 0, skipping `NullUnet` instances."""
        print_final_error = False

        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        if print_final_error:
            self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # data related functions

    def add_train_dataloader(self, dl = None):
        """Register `dl` (after `accelerator.prepare`) as the training dataloader; does nothing when `dl` is None."""
        if not exists(dl):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'
        self.train_dl = self.accelerator.prepare(dl)

    def add_valid_dataloader(self, dl):
        """Register `dl` (after `accelerator.prepare`) as the validation dataloader; does nothing when `dl` is None."""
        if not exists(dl):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'
        self.valid_dl = self.accelerator.prepare(dl)

    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
        """Build and register the training DataLoader from `ds`.

        When `split_valid_from_train` is set, a `random_split` seeded with
        `split_random_seed` keeps `int((1 - split_valid_fraction) * len(ds))`
        samples for training and registers the remainder as the validation
        dataset, using the same `batch_size` and `dl_kwargs`.
        """
        if not exists(ds):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'

        valid_ds = None
        if self.split_valid_from_train:
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples')

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.train_dl = self.accelerator.prepare(dl)

        if not self.split_valid_from_train:
            return

        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)

    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        """Build a DataLoader from `ds` and register it as the validation dataloader; does nothing when `ds` is None."""
        if not exists(ds):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.valid_dl = self.accelerator.prepare(dl)

    def create_train_iter(self):
        """Create the endless (`cycle`) training iterator once; asserts that a training dataloader is registered."""
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    def create_valid_iter(self):
        """Create the endless (`cycle`) validation iterator once; asserts that a validation dataloader is registered."""
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

    def train_step(self, unet_number = None, **kwargs):
        """Run `forward` on the next training batch, then `update`; returns the loss."""
        self.create_train_iter()
        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
        self.update(unet_number = unet_number)
        return loss

    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        """Return the loss on the next validation batch.

        The `use_ema_unets` keyword (default False) is popped from `kwargs` and
        selects whether the step runs inside the `use_ema_unets` context.
        """
        self.create_valid_iter()

        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext

        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss

    def step_with_dl_iter(self, dl_iter, **kwargs):
        """Pull the next batch, name its entries with `dl_tuple_output_keywords_names`, and call `forward`.

        Batch entries override same-named entries of `kwargs`.
        """
        dl_tuple_output = cast_tuple(next(dl_iter))
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        loss = self.forward(**{**kwargs, **model_input})
        return loss

    # checkpointing functions

    @property
    def all_checkpoints_sorted(self):
        """`*.pt` files under `checkpoint_path`, highest step count first.

        The sort key is the integer between the last two dots of the name,
        i.e. the `{total_steps}` part of `checkpoint.{total_steps}.pt`.
        """
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
        return sorted_checkpoints

    def load_from_checkpoint_folder(self, last_total_steps = -1):
        """Load `checkpoint.{last_total_steps}.pt`, or the highest-step checkpoint when `last_total_steps` is -1.

        Prints a message and returns when the folder holds no checkpoints.
        """
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    def save_to_checkpoint_folder(self):
        """Save `checkpoint.{total_steps}.pt` and delete all but the `max_checkpoints_keep` highest-step files.

        Pruning is skipped when `max_checkpoints_keep <= 0`. Processes without
        `can_checkpoint` return after the barrier.
        """
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)

    # saving and loading functions

    def save(
        self,
        path,
        overwrite = True,
        without_optim_and_sched = False,
        **kwargs
    ):
        """Write a checkpoint dict to `path` through `self.fs`.

        The dict holds the `imagen` state dict, `__version__`, the step counts,
        any extra `kwargs`, and - unless `without_optim_and_sched` - the scaler,
        optimizer, scheduler and warmup states of every unet. EMA weights are
        added when `use_ema` is set. Asserts that `path` does not already exist
        when `overwrite` is False.
        """
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite)

        self.reset_ema_unets_all_one_device()

        save_obj = dict(
            model = self.imagen.state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine if imagen config is available

        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""')

            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        #save to path

        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')

    def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
        """Load a checkpoint dict from `path` and restore trainer state from it.

        Returns the loaded dict, or None when `noop_if_not_exist` is set and the
        file is missing. With `only_model` only the `imagen` weights are
        restored. A `RuntimeError` from `load_state_dict` triggers a second
        attempt through `restore_parts`.
        """
        fs = self.fs

        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        assert fs.exists(path), f'{path} does not exist'

        self.reset_ema_unets_all_one_device()

        # to avoid extra GPU memory usage in main process when using Accelerate

        with fs.open(path) as f:
            loaded_obj = torch.load(f, map_location='cpu')

        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        if only_model:
            return loaded_obj

        self.steps.copy_(loaded_obj['steps'])

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            if exists(optimizer):
                # NOTE(review): the bare `except:` below also hides a missing key (checkpoint
                # saved with without_optim_and_sched) and swallows KeyboardInterrupt/SystemExit.
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except:
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

    # managing ema unets and their devices

    @property
    def unets(self):
        """A newly built ModuleList holding the `ema_model` of every EMA wrapper."""
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def get_ema_unet(self, unet_number = None):
        """Return the EMA wrapper for `unet_number` (None when EMA is disabled).

        When the requested index differs from the one tracked in
        `ema_unet_being_trained_index`, the requested wrapper is moved to the
        trainer device and all others to CPU.
        """
        if not self.use_ema:
            return

        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        # NOTE(review): `self.unets` is the property above, which always builds a ModuleList,
        # so this branch runs on every call and replaces `ema_unets` with a plain list.
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    def reset_ema_unets_all_one_device(self, device = None):
        """Re-wrap the EMA wrappers in a ModuleList, move them to `device` (default: trainer device) and reset the tracked index to -1."""
        if not self.use_ema:
            return

        device = default(device, self.device)
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        self.ema_unet_being_trained_index = -1

    @torch.no_grad()
    @contextmanager
    def use_ema_unets(self):
        """Context manager that swaps the EMA unets into `self.imagen.unets` for the duration of the block.

        When EMA is disabled it yields without changing anything.
        """
        if not self.use_ema:
            output = yield
            return output

        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        self.unets.eval()

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets                  # swap in exponential moving averaged unets for sampling

        # NOTE(review): no try/finally around the yield - an exception raised inside the
        # with-block skips the restore below.
        output = yield

        self.imagen.unets = trainable_unets             # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output

    def print_unet_devices(self):
        """Print the device of the first parameter of every unet and, when EMA is enabled, of every EMA unet."""
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # overriding state dict
functions def state_dict(self, *args, **kwargs): self.reset_ema_unets_all_one_device() return super().state_dict(*args, **kwargs) def load_state_dict(self, *args, **kwargs): self.reset_ema_unets_all_one_device() return super().load_state_dict(*args, **kwargs) # encoding text functions def encode_text(self, text, **kwargs): return self.imagen.encode_text(text, **kwargs) # forwarding functions and gradient step updates def update(self, unet_number = None): unet_number = self.validate_unet_number(unet_number) self.validate_and_set_unet_being_trained(unet_number) self.set_accelerator_scaler(unet_number) index = unet_number - 1 unet = self.unet_being_trained optimizer = getattr(self, f'optim{index}') scaler = getattr(self, f'scaler{index}') scheduler = getattr(self, f'scheduler{index}') warmup_scheduler = getattr(self, f'warmup{index}') # set the grad scaler on the accelerator, since we are managing one per u-net if exists(self.max_grad_norm): self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) optimizer.step() optimizer.zero_grad() if self.use_ema: ema_unet = self.get_ema_unet(unet_number) ema_unet.update() # scheduler, if needed maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() with maybe_warmup_context: if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs scheduler.step() self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) if not exists(self.checkpoint_path): return total_steps = int(self.steps.sum().item()) if total_steps % self.checkpoint_every: return self.save_to_checkpoint_folder() @torch.no_grad() @cast_torch_tensor @imagen_sample_in_chunks def sample(self, *args, **kwargs): context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets self.print_untrained_unets() if not self.is_main: kwargs['use_tqdm'] = False with context(): output = self.imagen.sample(*args, 
device = self.device, **kwargs) return output @partial(cast_torch_tensor, cast_fp16 = True) def forward( self, *args, unet_number = None, max_batch_size = None, **kwargs ): unet_number = self.validate_unet_number(unet_number) self.validate_and_set_unet_being_trained(unet_number) self.set_accelerator_scaler(unet_number) assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' total_loss = 0. # + for debug if self.CKeys['Debug_TrainerPack']==1: print("In Trainer:Forward, check inputs:") print('args: ', len(args)) print('args in:',args[0].shape) print('kwargs: ', kwargs.keys()) for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): # + for debug if self.CKeys['Debug_TrainerPack']==1: print("after chunks,...") print('chun_frac: ', chunk_size_frac) print('chun_args: ', chunked_args) print('chun_kwargs: ', chunked_kwargs) with self.accelerator.autocast(): loss = self.model( *chunked_args, unet_number = unet_number, **chunked_kwargs ) loss = loss * chunk_size_frac # + for debug if self.CKeys['Debug_TrainerPack']==1: print('part chun loss: ', loss) total_loss += loss#.item() if self.training: self.accelerator.backward(loss) return total_loss class ImagenTrainer_Old(nn.Module): locked = False def __init__( self, #imagen = None, model = None, imagen_checkpoint_path = None, use_ema = True, lr = 1e-4, eps = 1e-8, beta1 = 0.9, beta2 = 0.99, max_grad_norm = None, group_wd_params = True, warmup_steps = None, cosine_decay_max_steps = None, only_train_unet_number = None, fp16 = False, precision = None, split_batches = True, dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), verbose = True, split_valid_fraction = 0.025, split_valid_from_train = False, split_random_seed = 42, checkpoint_path = None, checkpoint_every = None, checkpoint_fs = None, fs_kwargs: dict = None, 
max_checkpoints_keep = 20, **kwargs ): super().__init__() assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' # determine filesystem, using fsspec, for saving to local filesystem or cloud self.fs = checkpoint_fs if not exists(self.fs): fs_kwargs = default(fs_kwargs, {}) self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) # # ----------------------------------- # # from MJB # assert isinstance(model.imagen, (ProteinDesigner_B)) ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) self.imagen = model.imagen self.model=model self.is_elucidated = self.model.is_elucidated # create accelerator instance accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') self.accelerator = Accelerator(**{ 'split_batches': split_batches, 'mixed_precision': accelerator_mixed_precision, 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] , **accelerate_kwargs}) ImagenTrainer.locked = self.is_distributed # cast data to fp16 at training time if needed self.cast_half_at_training = accelerator_mixed_precision == 'fp16' # grad scaler must be managed outside of accelerator grad_scaler_enabled = fp16 self.num_unets = len(self.imagen.unets) self.use_ema = use_ema and self.is_main self.ema_unets = nn.ModuleList([]) # keep track of what unet is being trained on # only going to allow 1 unet training at a time self.ema_unet_being_trained_index = -1 # keeps track of which ema 
unet is being trained on # data related functions self.train_dl_iter = None self.train_dl = None self.valid_dl_iter = None self.valid_dl = None self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names # auto splitting validation from training, if dataset is passed in self.split_valid_from_train = split_valid_from_train assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' self.split_valid_fraction = split_valid_fraction self.split_random_seed = split_random_seed # be able to finely customize learning rate, weight decay # per unet lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): optimizer = Adam( unet.parameters(), lr = unet_lr, eps = unet_eps, betas = (beta1, beta2), **kwargs ) if self.use_ema: self.ema_unets.append(EMA(unet, **ema_kwargs)) scaler = GradScaler(enabled = grad_scaler_enabled) scheduler = warmup_scheduler = None if exists(unet_cosine_decay_max_steps): scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) if exists(unet_warmup_steps): warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if not exists(scheduler): scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) # set on object setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers setattr(self, f'scaler{ind}', scaler) setattr(self, f'scheduler{ind}', scheduler) setattr(self, f'warmup{ind}', warmup_scheduler) # gradient clipping if needed self.max_grad_norm = max_grad_norm # step tracker and misc self.register_buffer('steps', torch.tensor([0] * self.num_unets)) self.verbose = verbose # automatic set devices based on what accelerator decided self.imagen.to(self.device) 
self.to(self.device) # checkpointing assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) self.checkpoint_path = checkpoint_path self.checkpoint_every = checkpoint_every self.max_checkpoints_keep = max_checkpoints_keep self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main if exists(checkpoint_path) and self.can_checkpoint: bucket = url_to_bucket(checkpoint_path) if not self.fs.exists(bucket): self.fs.mkdir(bucket) self.load_from_checkpoint_folder() # only allowing training for unet self.only_train_unet_number = only_train_unet_number self.validate_and_set_unet_being_trained(only_train_unet_number) # computed values @property def device(self): return self.accelerator.device @property def is_distributed(self): return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) @property def is_main(self): return self.accelerator.is_main_process @property def is_local_main(self): return self.accelerator.is_local_main_process @property def unwrapped_unet(self): return self.accelerator.unwrap_model(self.unet_being_trained) # optimizer helper functions def get_lr(self, unet_number): self.validate_unet_number(unet_number) unet_index = unet_number - 1 optim = getattr(self, f'optim{unet_index}') return optim.param_groups[0]['lr'] # function for allowing only one unet from being trained at a time def validate_and_set_unet_being_trained(self, unet_number = None): if exists(unet_number): self.validate_unet_number(unet_number) assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. 
you will need to save the trainer into a checkpoint, and resume training on a new unet' self.only_train_unet_number = unet_number self.imagen.only_train_unet_number = unet_number if not exists(unet_number): return self.wrap_unet(unet_number) def wrap_unet(self, unet_number): if hasattr(self, 'one_unet_wrapped'): return unet = self.imagen.get_unet(unet_number) self.unet_being_trained = self.accelerator.prepare(unet) unet_index = unet_number - 1 optimizer = getattr(self, f'optim{unet_index}') scheduler = getattr(self, f'scheduler{unet_index}') optimizer = self.accelerator.prepare(optimizer) if exists(scheduler): scheduler = self.accelerator.prepare(scheduler) setattr(self, f'optim{unet_index}', optimizer) setattr(self, f'scheduler{unet_index}', scheduler) self.one_unet_wrapped = True # hacking accelerator due to not having separate gradscaler per optimizer def set_accelerator_scaler(self, unet_number): unet_number = self.validate_unet_number(unet_number) scaler = getattr(self, f'scaler{unet_number - 1}') self.accelerator.scaler = scaler for optimizer in self.accelerator._optimizers: optimizer.scaler = scaler # helper print def print(self, msg): if not self.is_main: return if not self.verbose: return return self.accelerator.print(msg) # validating the unet number def validate_unet_number(self, unet_number = None): if self.num_unets == 1: unet_number = default(unet_number, 1) assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' return unet_number # number of training steps taken def num_steps_taken(self, unet_number = None): if self.num_unets == 1: unet_number = default(unet_number, 1) return self.steps[unet_number - 1].item() def print_untrained_unets(self): print_final_error = False for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): if steps > 0 or isinstance(unet, NullUnet): continue self.print(f'unet {ind + 1} has not been trained') print_final_error = True if print_final_error: 
self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') # data related functions def add_train_dataloader(self, dl = None): if not exists(dl): return assert not exists(self.train_dl), 'training dataloader was already added' self.train_dl = self.accelerator.prepare(dl) def add_valid_dataloader(self, dl): if not exists(dl): return assert not exists(self.valid_dl), 'validation dataloader was already added' self.valid_dl = self.accelerator.prepare(dl) def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs): if not exists(ds): return assert not exists(self.train_dl), 'training dataloader was already added' valid_ds = None if self.split_valid_from_train: train_size = int((1 - self.split_valid_fraction) * len(ds)) valid_size = len(ds) - train_size ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed)) self.print(f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples') dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) self.train_dl = self.accelerator.prepare(dl) if not self.split_valid_from_train: return self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs) def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs): if not exists(ds): return assert not exists(self.valid_dl), 'validation dataloader was already added' dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) self.valid_dl = self.accelerator.prepare(dl) def create_train_iter(self): assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet' if exists(self.train_dl_iter): return self.train_dl_iter = cycle(self.train_dl) def create_valid_iter(self): assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet' if exists(self.valid_dl_iter): return self.valid_dl_iter = cycle(self.valid_dl) def 
train_step(self, unet_number = None, **kwargs): self.create_train_iter() loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) self.update(unet_number = unet_number) return loss @torch.no_grad() @eval_decorator def valid_step(self, **kwargs): self.create_valid_iter() context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext with context(): loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs) return loss def step_with_dl_iter(self, dl_iter, **kwargs): dl_tuple_output = cast_tuple(next(dl_iter)) model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output))) loss = self.forward(**{**kwargs, **model_input}) return loss # checkpointing functions @property def all_checkpoints_sorted(self): glob_pattern = os.path.join(self.checkpoint_path, '*.pt') checkpoints = self.fs.glob(glob_pattern) sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True) return sorted_checkpoints def load_from_checkpoint_folder(self, last_total_steps = -1): if last_total_steps != -1: filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt') self.load(filepath) return sorted_checkpoints = self.all_checkpoints_sorted if len(sorted_checkpoints) == 0: self.print(f'no checkpoints found to load from at {self.checkpoint_path}') return last_checkpoint = sorted_checkpoints[0] self.load(last_checkpoint) def save_to_checkpoint_folder(self): self.accelerator.wait_for_everyone() if not self.can_checkpoint: return total_steps = int(self.steps.sum().item()) filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt') self.save(filepath) if self.max_checkpoints_keep <= 0: return sorted_checkpoints = self.all_checkpoints_sorted checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] for checkpoint in checkpoints_to_discard: self.fs.rm(checkpoint) # saving and loading functions def save( self, path, overwrite = True, 
without_optim_and_sched = False, **kwargs ): self.accelerator.wait_for_everyone() if not self.can_checkpoint: return fs = self.fs assert not (fs.exists(path) and not overwrite) self.reset_ema_unets_all_one_device() save_obj = dict( model = self.imagen.state_dict(), version = __version__, steps = self.steps.cpu(), **kwargs ) save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() for ind in save_optim_and_sched_iter: scaler_key = f'scaler{ind}' optimizer_key = f'optim{ind}' scheduler_key = f'scheduler{ind}' warmup_scheduler_key = f'warmup{ind}' scaler = getattr(self, scaler_key) optimizer = getattr(self, optimizer_key) scheduler = getattr(self, scheduler_key) warmup_scheduler = getattr(self, warmup_scheduler_key) if exists(scheduler): save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} if exists(warmup_scheduler): save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} if self.use_ema: save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} # determine if imagen config is available if hasattr(self.imagen, '_config'): self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""') save_obj = { **save_obj, 'imagen_type': 'elucidated' if self.is_elucidated else 'original', 'imagen_params': self.imagen._config } #save to path with fs.open(path, 'wb') as f: torch.save(save_obj, f) self.print(f'checkpoint saved to {path}') def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): fs = self.fs if noop_if_not_exist and not fs.exists(path): self.print(f'trainer checkpoint not found at {str(path)}') return assert fs.exists(path), f'{path} does not exist' self.reset_ema_unets_all_one_device() # to avoid extra GPU memory usage in main process when using Accelerate with fs.open(path) as f: loaded_obj = torch.load(f, map_location='cpu') if 
version.parse(__version__) != version.parse(loaded_obj['version']): self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') try: self.imagen.load_state_dict(loaded_obj['model'], strict = strict) except RuntimeError: print("Failed loading state dict. Trying partial load") self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), loaded_obj['model'])) if only_model: return loaded_obj self.steps.copy_(loaded_obj['steps']) for ind in range(0, self.num_unets): scaler_key = f'scaler{ind}' optimizer_key = f'optim{ind}' scheduler_key = f'scheduler{ind}' warmup_scheduler_key = f'warmup{ind}' scaler = getattr(self, scaler_key) optimizer = getattr(self, optimizer_key) scheduler = getattr(self, scheduler_key) warmup_scheduler = getattr(self, warmup_scheduler_key) if exists(scheduler) and scheduler_key in loaded_obj: scheduler.load_state_dict(loaded_obj[scheduler_key]) if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) if exists(optimizer): try: optimizer.load_state_dict(loaded_obj[optimizer_key]) scaler.load_state_dict(loaded_obj[scaler_key]) except: self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers') if self.use_ema: assert 'ema' in loaded_obj try: self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) except RuntimeError: print("Failed loading state dict. 
# managing ema unets and their devices

@property
def unets(self):
    """Return the ``ema_model`` of every entry in ``self.ema_unets`` as a ModuleList."""
    return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

def get_ema_unet(self, unet_number = None):
    """Return the EMA wrapper for ``unet_number`` and move it onto ``self.device``.

    All other EMA wrappers are moved to the CPU when the requested index
    differs from the one currently tracked. Returns ``None`` when EMA is
    disabled.
    """
    if not self.use_ema:
        return

    unet_number = self.validate_unet_number(unet_number)
    index = unet_number - 1

    # Hold the wrappers in a plain list so each one can live on its own
    # device. The original tested `self.unets`, which is always a freshly
    # built ModuleList; testing `self.ema_unets` gives the same end state
    # without rebuilding a ModuleList on every call.
    if isinstance(self.ema_unets, nn.ModuleList):
        unets_list = [unet for unet in self.ema_unets]
        delattr(self, 'ema_unets')
        self.ema_unets = unets_list

    if index != self.ema_unet_being_trained_index:
        for unet_index, unet in enumerate(self.ema_unets):
            unet.to(self.device if unet_index == index else 'cpu')

    self.ema_unet_being_trained_index = index
    return self.ema_unets[index]

def reset_ema_unets_all_one_device(self, device = None):
    """Re-wrap the EMA unets in a ModuleList and move them all to ``device``.

    ``device`` falls back to ``self.device``. No-op when EMA is disabled.
    """
    if not self.use_ema:
        return

    device = default(device, self.device)
    self.ema_unets = nn.ModuleList([*self.ema_unets])
    self.ema_unets.to(device)

    # -1 marks "no single EMA unet is currently selected"
    self.ema_unet_being_trained_index = -1

@torch.no_grad()
@contextmanager
def use_ema_unets(self):
    """Context manager that temporarily swaps the EMA unets into ``self.imagen``.

    When EMA is disabled the context does nothing. Otherwise the trainable
    unets are restored, and every EMA wrapper's ``restore_ema_model_device``
    is called, on exit. This now also happens when the ``with`` body raises.
    """
    if not self.use_ema:
        output = yield
        return output

    self.reset_ema_unets_all_one_device()
    self.imagen.reset_unets_all_one_device()

    self.unets.eval()

    trainable_unets = self.imagen.unets
    self.imagen.unets = self.unets  # swap in exponential moving averaged unets for sampling

    try:
        output = yield
    finally:
        # Without the finally, an exception inside the with-body would leave
        # the EMA unets installed in place of the trainable ones.
        self.imagen.unets = trainable_unets  # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

    return output

def print_unet_devices(self):
    """Report, via ``self.print``, the device of each unet and (if enabled) each EMA unet."""
    self.print('unet devices:')
    for i, unet in enumerate(self.imagen.unets):
        unet_device = next(unet.parameters()).device
        self.print(f'\tunet {i}: {unet_device}')

    if not self.use_ema:
        return

    self.print('\nema unet devices:')
    for i, ema_unet in enumerate(self.ema_unets):
        ema_device = next(ema_unet.parameters()).device
        self.print(f'\tema unet {i}: {ema_device}')

# overriding state dict
# functions

def state_dict(self, *args, **kwargs):
    """Gather the EMA unets onto one device, then defer to the parent ``state_dict``."""
    self.reset_ema_unets_all_one_device()
    parent = super()
    return parent.state_dict(*args, **kwargs)

def load_state_dict(self, *args, **kwargs):
    """Gather the EMA unets onto one device, then defer to the parent ``load_state_dict``."""
    self.reset_ema_unets_all_one_device()
    parent = super()
    return parent.load_state_dict(*args, **kwargs)

# encoding text functions

def encode_text(self, text, **kwargs):
    """Delegate text encoding to ``self.imagen.encode_text``."""
    encoder = self.imagen.encode_text
    return encoder(text, **kwargs)

# forwarding functions and gradient step updates

def update(self, unet_number = None):
    """Apply one optimizer step for the selected unet and do the bookkeeping.

    In order: optional gradient clipping, optimizer step and zero_grad, EMA
    update, scheduler step (inside the warmup dampening context when a warmup
    scheduler exists), per-unet step counter increment, and a checkpoint save
    whenever the total step count is a multiple of ``self.checkpoint_every``.
    """
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    index = unet_number - 1
    unet = self.unet_being_trained

    # per-unet training objects are stored as numbered attributes
    optimizer = getattr(self, f'optim{index}')
    scaler = getattr(self, f'scaler{index}')  # looked up as in the original; not used below
    scheduler = getattr(self, f'scheduler{index}')
    warmup_scheduler = getattr(self, f'warmup{index}')

    max_norm = self.max_grad_norm
    if exists(max_norm):
        self.accelerator.clip_grad_norm_(unet.parameters(), max_norm)

    optimizer.step()
    optimizer.zero_grad()

    if self.use_ema:
        self.get_ema_unet(unet_number).update()

    # scheduler, if needed
    if exists(warmup_scheduler):
        maybe_warmup_context = warmup_scheduler.dampening()
    else:
        maybe_warmup_context = nullcontext()

    with maybe_warmup_context:
        step_was_taken = exists(scheduler) and not self.accelerator.optimizer_step_was_skipped
        if step_was_taken:  # recommended in the docs
            scheduler.step()

    # one-hot increment of the counter that belongs to this unet
    position = torch.tensor(unet_number - 1, device = self.steps.device)
    self.steps += F.one_hot(position, num_classes = len(self.steps))

    if not exists(self.checkpoint_path):
        return

    total_steps = int(self.steps.sum().item())
    if total_steps % self.checkpoint_every == 0:
        self.save_to_checkpoint_folder()
@torch.no_grad()
@cast_torch_tensor
@imagen_sample_in_chunks
def sample(self, *args, **kwargs):
    """Sample from ``self.imagen``, using the EMA unets unless ``use_non_ema=True`` is passed.

    ``use_non_ema`` is consumed here and not forwarded. On non-main processes
    ``use_tqdm`` is forced to False. Everything else is passed through to
    ``self.imagen.sample`` together with ``device=self.device``.
    """
    context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

    self.print_untrained_unets()

    if not self.is_main:
        kwargs['use_tqdm'] = False

    with context():
        output = self.imagen.sample(*args, device = self.device, **kwargs)

    return output

@partial(cast_torch_tensor, cast_fp16 = True)
def forward(
    self,
    *args,
    unet_number = None,
    max_batch_size = None,
    **kwargs
):
    """Compute the training loss for one unet, chunking the batch if requested.

    Args:
        *args, **kwargs: forwarded (per chunk) to ``self.model``.
        unet_number: which unet to train; validated by ``validate_unet_number``.
        max_batch_size: chunk size handed to ``split_args_and_kwargs``.

    Returns:
        The sum over chunks of ``loss * chunk_size_frac``. The accumulated
        value is not converted with ``.item()``.

    Side effects:
        Calls ``self.accelerator.backward`` on each chunk loss when
        ``self.training`` is set, and prints debug information.
    """
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

    total_loss = 0.

    # + for debug
    print('args: ', len(args))
    # Guarded: the original indexed args[0] unconditionally, so a call that
    # passed everything by keyword failed with IndexError in a debug print.
    if args and hasattr(args[0], 'shape'):
        print('args in:', args[0].shape)
    print('kwargs: ', kwargs.keys())

    for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
        # + for debug
        print('chun_frac: ', chunk_size_frac)
        print('chun_args: ', chunked_args)
        print('chun_kwargs: ', chunked_kwargs)

        with self.accelerator.autocast():
            loss = self.model(*chunked_args, unet_number = unet_number, **chunked_kwargs)
            loss = loss * chunk_size_frac

        print('loss: ', loss)
        total_loss += loss#.item()

        if self.training:
            self.accelerator.backward(loss)

    # NOTE(review): the source formatting does not show whether this debug
    # print sat inside or after the chunk loop; it is placed after it here.
    print('I am here')
    return total_loss
def write_fasta (sequence, filename_out):
    """Write ``sequence`` to ``filename_out`` as a single-record FASTA file.

    The header line is ``>`` followed by ``filename_out``; the sequence is
    written without a trailing newline.
    """
    with open (filename_out, mode ='w') as f:
        f.write (f'>{filename_out}\n')
        f.write (f'{sequence}')

#
def sample_sequence_FromModelB (
        model,
        X=None, #this is the target conventionally when using text embd
        flag=0,
        cond_scales=1.,
        foldproteins=False,
        X_string=None,
        x_data=None,
        skip_steps=0,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = None,
        init_images = None,
        num_cycle=16,
        # ++++++++++++++++++++++++
        ynormfac=1,
        train_unet_number=1,
        tokenizer_X=None,
        Xnormfac=1.,
        max_length=1.,
        prefix=None,
        tokenizer_y=None,
        ):
    """Sample sequences from the model for each conditioning entry and optionally fold them.

    Exactly one of ``X``, ``X_string`` or ``x_data`` is expected; if several
    are given, the last one in that order sets the number of samples. For
    every entry the function calls ``model.sample``, rounds the result scaled
    by ``ynormfac``, saves a plot under ``prefix``, decodes the tokens with
    ``tokenizer_y`` and, when ``foldproteins`` is set, folds the first decoded
    sequence and writes FASTA/PDB files under ``prefix``.

    Returns:
        Path of the last PDB file produced, or ``None`` when nothing was folded
        (the original raised ``UnboundLocalError`` in that case).

    Raises:
        ValueError: if none of ``X``, ``X_string`` and ``x_data`` is provided.
    """
    steps=0
    e=flag

    # `is not None` rather than `!= None`: the latter is evaluated element-wise
    # for numpy inputs and cannot be used as a condition.
    lenn_val=None
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...")
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditingig x_data ...")
        lenn_val=len(x_data)
    if lenn_val is None:
        raise ValueError('one of X, X_string or x_data must be provided')

    print (x_data)
    print ('Device: ', device)

    pdb_file=None  # stays None when foldproteins is False or there is nothing to sample

    for iisample in range (lenn_val):
        X_cond=None
        if X_string is None and X is not None: #only do if X provided
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None:
            # NOTE: X is rebound here (as in the original), so the summary
            # print below shows the tokenized array for string input.
            X = tokenizer_X.texts_to_sequences(X_string[iisample])
            # NOTE(review): `sequence` must be a module-level name providing
            # pad_sequences; it is not among the imports visible here - confirm.
            X= sequence.pad_sequences(X, maxlen=max_length, padding='post', truncating='post')
            X=np.array(X)
            X_cond=torch.from_numpy(X).float()/Xnormfac
            print ('Tokenized and processed: ', X_cond)

        print ("X_cond=", X_cond)

        result=model.sample (
            x=X_cond,
            stop_at_unet_number=train_unet_number ,
            cond_scale=cond_scales ,
            x_data=x_data,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,device=device,
            # ++++++++++++++++++++++++++
            tokenizer_X=tokenizer_X,
            Xnormfac=Xnormfac,
            max_length=max_length,
            )
        result=torch.round(result*ynormfac)

        plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted')
        plt.legend()
        outname = prefix+ f"sampled_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        plt.savefig(outname, dpi=200)
        plt.show ()

        # decode the predicted tokens back into text
        to_rev=result[:,0,:]
        to_rev=to_rev.long().cpu().detach().numpy()
        print (to_rev.shape)
        y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
        y_data_reversed=[seq.upper().strip().replace(" ", "") for seq in y_data_reversed]

        ### reverse second structure input....
        if X_cond is not None:
            X_cond=torch.round(X_cond*Xnormfac)
            to_rev=X_cond[:,:]
            to_rev=to_rev.long().cpu().detach().numpy()
            print (to_rev.shape)
            X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
            # clean every decoded entry; the original bounded this loop by
            # len(y_data_reversed), i.e. by the length of a different list
            X_data_reversed=[seq.upper().strip().replace(" ", "") for seq in X_data_reversed]
        if x_data is not None:
            X_data_reversed=x_data #is already in sequence format..

        print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample])

        if foldproteins:
            if X_cond is not None:
                xbc=X_cond[iisample,:].cpu().detach().numpy()
                out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}'
            if x_data is not None:
                out_nam=x_data[iisample]

            tempname='temp'
            pdb_file=foldandsavePDB (
                sequence=y_data_reversed[0],
                filename_out=tempname,
                num_cycle=num_cycle,
                flag=flag,
                # +++++++++++++++++++
                prefix=prefix
                )

            out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta'
            write_fasta (y_data_reversed[0], out_nam_fasta)

            out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
            shutil.copy (pdb_file, out_nam) #source, dest
            pdb_file=out_nam
            print (f"Properly named PDB file produced: {pdb_file}")

            # NOTE(review): show_sidechains / show_mainchains / color are not
            # parameters of this function; they must exist at module level.
            view=show_pdb(pdb_file=pdb_file, flag=flag,
                          show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color)
            view.show()

        steps=steps+1

    return pdb_file
def sample_loop_FromModelB (model,
                train_loader,
                cond_scales=[7.5], #list of cond scales - each sampled... (never mutated here)
                num_samples=2, #how many samples produced every time tested.....
                timesteps=100,
                flag=0,foldproteins=False,
                use_text_embedd=True,skip_steps=0,
                # +++++++++++++++++++
                train_unet_number=1,
                ynormfac=1,
                prefix=None,
                tokenizer_y=None,
                Xnormfac=1,
                tokenizer_X=None,
               ):
    """Sample from the model on batches of ``train_loader`` and compare with the ground truth.

    For each batch and each entry of ``cond_scales`` the model is sampled
    (text channel when ``use_text_embedd`` is set, otherwise the
    ``x_data_tokenized`` channel). The first ``num_samples`` predictions are
    plotted against the ground truth and saved under ``prefix``, decoded with
    the tokenizers and printed. With ``foldproteins`` set, each decoded
    sequence is folded and its PDB copied under ``prefix``.

    ``timesteps`` is accepted but not used by this function. Returns ``None``.
    """
    steps=0
    e=flag

    for item in train_loader:
        X_train_batch= item[0].to(device)
        y_train_batch=item[1].to(device)

        GT=y_train_batch.cpu().detach()
        GT= GT.unsqueeze(1)
        # Scale the ground truth once per batch. The original re-applied
        # `GT*ynormfac` inside the cond-scale loop, compounding the factor
        # for every additional cond scale when ynormfac != 1.
        GT=torch.round (GT*ynormfac)

        num_samples = min (num_samples,y_train_batch.shape[0] )
        print (f"Producing {num_samples} samples...")
        print ('X_train_batch shape: ', X_train_batch.shape)

        # decoding of the ground truth and of the conditioning does not depend
        # on the cond scale or the sample index, so it is done once per batch
        #reverse GT_y sequence
        to_rev=GT[:,0,:]
        to_rev=to_rev.long().cpu().detach().numpy()
        GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
        GT_y_data_reversed=[seq.upper().strip().replace(" ", "") for seq in GT_y_data_reversed]

        ### reverse second structure input....
        to_rev=torch.round (X_train_batch[:,:]*Xnormfac)
        to_rev=to_rev.long().cpu().detach().numpy()
        X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
        X_data_reversed=[seq.upper().strip().replace(" ", "") for seq in X_data_reversed]

        for cond_scale in cond_scales:
            if use_text_embedd:
                result=model.sample (x= X_train_batch,stop_at_unet_number=train_unet_number ,
                                     cond_scale=cond_scale, device=device, skip_steps=skip_steps)
            else:
                result=model.sample (x= None, x_data_tokenized= X_train_batch,
                                     stop_at_unet_number=train_unet_number ,
                                     cond_scale=cond_scale,device=device,skip_steps=skip_steps)

            result=torch.round(result*ynormfac)

            #reverse y sequence
            to_rev=result[:,0,:]
            to_rev=to_rev.long().cpu().detach().numpy()
            y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
            y_data_reversed=[seq.upper().strip().replace(" ", "") for seq in y_data_reversed]

            for samples in range (num_samples):
                print ("sample ", samples, "out of ", num_samples)

                plt.plot (result[samples,0,:].cpu().detach().numpy(),label= f'Predicted')
                plt.plot (GT[samples,0,:],label= f'GT {0}')
                plt.legend()

                outname = prefix+ f"sample-{samples}_condscale-{str (cond_scale)}_{e}_{steps}.jpg"
                plt.savefig(outname, dpi=200)
                plt.show ()

                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples])
                print (f"Ground truth: {GT_y_data_reversed[samples]}")

                if foldproteins:
                    tempname='temp'
                    pdb_file=foldandsavePDB (
                        sequence=y_data_reversed[samples],
                        filename_out=tempname,
                        num_cycle=16,
                        flag=flag,
                        # +++++++++++++++++++
                        prefix=prefix
                        )

                    out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                    shutil.copy (pdb_file, out_nam) #source, dest
                    pdb_file=out_nam
                    print (f"Properly named PDB file produced: {pdb_file}")

                    # NOTE(review): show_sidechains / show_mainchains / color are
                    # not parameters here; they must exist at module level.
                    view=show_pdb(pdb_file=pdb_file, flag=flag,
                                  show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color)
                    view.show()

                steps=steps+1

        # NOTE(review): the source formatting does not show the nesting of
        # this check; it is applied once per batch here - confirm.
        if steps>num_samples:
            break
# ++
def sample_sequence_omegafold_pLM_ModelB (
        model,
        X=None, #this is the target conventionally when using text embd
        flag=0,
        cond_scales=1.,
        foldproteins=False,
        X_string=None,
        x_data=None,
        skip_steps=0,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = None,
        init_images = None,
        num_cycle=16,
        # ++++++++++++++++++++++++
        ynormfac=1,
        train_unet_number=1,
        tokenizer_X=None,
        Xnormfac=1.,
        max_length=1.,
        prefix=None,
        tokenizer_y=None,
        # ++
        CKeys=None,
        sample_dir=None,
        steps=None,
        e=None,
        IF_showfig=True, # effective only after foldproteins=True
        # ++
        pLM_Model=None,
        pLM_Model_Name=None,
        image_channels=None,
        pLM_alphabet=None,
        ):
    """Sample pLM embeddings per conditioning entry, decode them to sequences and optionally fold.

    For every entry the function calls ``model.sample`` (with ``x_data`` when
    ``tokenizer_X`` is given, otherwise with a normalized ``x_data_tokenized``
    tensor), converts the embedding to tokens, plots them, reads a mask from
    the input and decodes a foldable sequence. With ``foldproteins`` set, the
    first decoded sequence is folded and the PDB/FASTA files are renamed to
    ``{sample_dir}DeNovoSampling_{i}_epo_{e}_step_{steps}``.

    ``CKeys`` may be ``None``; debug output is enabled only when
    ``CKeys['Debug_TrainerPack'] == 3``.

    Returns:
        ``(pdb_file_list, fasta_file_list)``; both are empty when
        ``foldproteins`` is False.

    Raises:
        ValueError: if none of ``X``, ``X_string`` and ``x_data`` is provided.
    """
    # `is not None` rather than `!= None`: the latter is evaluated element-wise
    # for numpy inputs and cannot be used as a condition.
    lenn_val=None
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...")
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditingig x_data ...")
        lenn_val=len(x_data)
    if lenn_val is None:
        raise ValueError('one of X, X_string or x_data must be provided')

    # The default CKeys=None used to fail on CKeys['Debug_TrainerPack'].
    debug_on = (CKeys or {}).get('Debug_TrainerPack') == 3

    print (x_data)
    print ('Device: ', device)

    pdb_file_list=[]
    fasta_file_list=[]

    # + for debug
    print('tot ', lenn_val)

    for iisample in range (lenn_val):
        print("Working on ", iisample)

        X_cond=None # this is for text-conditioning
        if X_string is None and X is not None: #only do if X provided
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None:
            XX = tokenizer_X.texts_to_sequences(X_string[iisample])
            # NOTE(review): `sequence` must be a module-level name providing
            # pad_sequences; it is not among the imports visible here - confirm.
            XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post')
            XX=np.array(XX)
            X_cond=torch.from_numpy(XX).float()/Xnormfac
            print ('Tokenized and processed: ', X_cond)

        print ("X_cond=", X_cond)

        # arguments shared by both conditioning channels
        sample_kwargs=dict(
            x=X_cond,
            stop_at_unet_number=train_unet_number ,
            cond_scale=cond_scales ,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,device=device,
            tokenizer_X=tokenizer_X,
            Xnormfac=Xnormfac,
            max_length=max_length,
            )

        # use cond_image as the conditioning; both branches are for cond_img,
        # not for text_cond
        if tokenizer_X is not None:
            # for SecStr+ModelB: x_data goes through tokenizer_X inside sample()
            result_embedding=model.sample (
                x_data=x_data[iisample],
                x_data_tokenized=None,
                **sample_kwargs,
                )
        else:
            # for ForcPath+ModelB: use the x_data_tokenized channel
            x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac)
            x_data_tokenized=x_data_tokenized.to(torch.float)
            # here, only one input list is read in
            x_data_tokenized=x_data_tokenized.unsqueeze(0) # [batch=1, seq_len]
            # leave channel expansion for the self.sample() to handle

            if debug_on:
                print("x_data_tokenized dim: ", x_data_tokenized.shape)
                print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                print("test x_data_tokenized!=None: ", x_data_tokenized is not None)

            result_embedding=model.sample (
                x_data=None,
                x_data_tokenized=x_data_tokenized,
                **sample_kwargs,
                )

        # ++ for pLM
        # result_embedding as image.dim: [batch, channels, seq_len]
        # result_tokens.dim: [batch, seq_len]
        if image_channels==33:
            result_tokens,result_logits = convert_into_tokens_using_prob(
                result_embedding,
                pLM_Model_Name,
            )
        else:
            result_tokens,result_logits = convert_into_tokens(
                pLM_Model,
                result_embedding,
                pLM_Model_Name,
            )

        result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]

        # + for debug
        print('result dim: ', result.shape)

        # plot sequence token code: esm (33 tokens)
        plt.figure()
        plt.plot (
            result[0,0,:].cpu().detach().numpy(),
            label= f'Predicted'
        )
        plt.legend()
        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        if IF_showfig==1:
            plt.show ()
        else:
            plt.savefig(outname, dpi=200)
            plt.close()

        # ++: for model B: using mask from the input
        # extract the mask/seq_len from input if possible
        if tokenizer_X is not None:
            # for SecStr+ModelB
            result_mask = read_mask_from_input(
                tokenized_data=None,
                mask_value=None,
                seq_data=x_data[iisample],
                max_seq_length=max_length,
            )
        else:
            # for ForcPath+ModelB
            result_mask = read_mask_from_input(
                tokenized_data=x_data_tokenized,
                mask_value=0,
                seq_data=None,
                max_seq_length=None,
            )

        to_rev=result[:,0,:] # token (batch,seq_len)

        if debug_on:
            print("on foldable result: ", to_rev[0])
            print("on result_logits: ", result_logits[0])
            print("on mask: ", result_mask[0])
            a = decode_one_ems_token_rec_for_folding_with_mask(
                to_rev[0],
                result_logits[0],
                pLM_alphabet,
                pLM_Model,
                result_mask[0],
            )
            print('One resu: ', a)

        y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask(
            to_rev,
            result_logits,
            pLM_alphabet,
            pLM_Model,
            result_mask,
        )

        if debug_on:
            print("on y_data_reversed[0]: ", y_data_reversed[0])

        ### reverse second structure input....
        if X_cond is not None:
            X_cond=torch.round(X_cond*Xnormfac)
            to_rev=X_cond[:,:]
            to_rev=to_rev.long().cpu().detach().numpy()
            print (to_rev.shape)
            X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
            # clean every decoded entry; the original bounded this loop by
            # len(y_data_reversed), i.e. by the length of a different list
            X_data_reversed=[seq.upper().strip().replace(" ", "") for seq in X_data_reversed]
        if x_data is not None:
            X_data_reversed=x_data #is already in sequence format..

        print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed)

        # + for debug
        print("================================================")
        print("foldproteins: ", foldproteins)

        if foldproteins:
            # The original also built an `out_nam` from X_cond[iisample] /
            # x_data[iisample] here; it was overwritten before any use, and the
            # X_cond lookup raised IndexError for iisample >= 1 because X_cond
            # holds a single row. That dead computation is dropped.
            tempname='temp'
            pdb_file, fasta_file=foldandsavePDB_pdb_fasta (
                sequence=y_data_reversed[0],
                filename_out=tempname,
                num_cycle=num_cycle,
                flag=flag,
                prefix=sample_dir,
                )

            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
            write_fasta (y_data_reversed[0], out_nam_fasta)

            out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'

            shutil.copy (pdb_file, out_nam) #source, dest
            shutil.copy (fasta_file, out_nam_fasta)
            # clean the slate to avoid mistakenly using the previous fasta file
            os.remove (pdb_file)
            os.remove (fasta_file)

            pdb_file=out_nam
            fasta_file=out_nam_fasta
            pdb_file_list.append(pdb_file)
            fasta_file_list.append(fasta_file)

            print (f"Properly named PDB file produced: {pdb_file}")

            if IF_showfig==1:
                # NOTE(review): show_sidechains / show_mainchains / color are
                # not parameters here; they must exist at module level.
                view=show_pdb(
                    pdb_file=pdb_file,
                    flag=flag,
                    show_sidechains=show_sidechains,
                    show_mainchains=show_mainchains,
                    color=color
                )
                view.show()

    return pdb_file_list, fasta_file_list
#
def sample_sequence_omegafold_ModelB (
        model,
        X=None, #this is the target conventionally when using text embd
        flag=0,
        cond_scales=1.,
        foldproteins=False,
        X_string=None,
        x_data=None,
        skip_steps=0,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = None,
        init_images = None,
        num_cycle=16,
        # ++++++++++++++++++++++++
        ynormfac=1,
        train_unet_number=1,
        tokenizer_X=None,
        Xnormfac=1.,
        max_length=1.,
        prefix=None,
        tokenizer_y=None,
        # ++
        CKeys=None,
        sample_dir=None,
        steps=None,
        e=None,
        IF_showfig=True, # effective only after foldproteins=True
        ):
    """Sample token sequences per conditioning entry, decode them and optionally fold.

    For every entry the function calls ``model.sample`` (with ``x_data`` when
    ``tokenizer_X`` is given, otherwise with a normalized ``x_data_tokenized``
    tensor), rounds the result scaled by ``ynormfac``, plots it, and decodes
    it with ``tokenizer_y``. With ``foldproteins`` set, the first decoded
    sequence is folded and FASTA/PDB files named
    ``{sample_dir}DeNovoSampling_{i}_epo_{e}_step_{steps}`` are written.

    ``CKeys`` may be ``None``; debug output is enabled only when
    ``CKeys['Debug_TrainerPack'] == 1``.

    Returns:
        Path of the PDB file from the last iteration, or ``None`` when the last
        iteration did not fold (or there was nothing to sample).

    Raises:
        ValueError: if none of ``X``, ``X_string`` and ``x_data`` is provided.
    """
    # `is not None` rather than `!= None`: the latter is evaluated element-wise
    # for numpy inputs and cannot be used as a condition.
    lenn_val=None
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...")
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditingig x_data ...")
        lenn_val=len(x_data)
    if lenn_val is None:
        raise ValueError('one of X, X_string or x_data must be provided')

    # The default CKeys=None used to fail on CKeys['Debug_TrainerPack'].
    debug_on = (CKeys or {}).get('Debug_TrainerPack') == 1

    print (x_data)
    print ('Device: ', device)

    # + for debug
    print('tot ', lenn_val)

    pdb_file=None  # defined even when the loop body never runs

    for iisample in range (lenn_val):
        print("Working on ", iisample)

        X_cond=None
        if X_string is None and X is not None: #only do if X provided
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None:
            XX = tokenizer_X.texts_to_sequences(X_string[iisample])
            # NOTE(review): `sequence` must be a module-level name providing
            # pad_sequences; it is not among the imports visible here - confirm.
            XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post')
            XX=np.array(XX)
            X_cond=torch.from_numpy(XX).float()/Xnormfac
            print ('Tokenized and processed: ', X_cond)

        print ("X_cond=", X_cond)

        # arguments shared by both conditioning channels
        sample_kwargs=dict(
            x=X_cond,
            stop_at_unet_number=train_unet_number ,
            cond_scale=cond_scales ,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,device=device,
            tokenizer_X=tokenizer_X,
            Xnormfac=Xnormfac,
            max_length=max_length,
            )

        if tokenizer_X is not None:
            result=model.sample (
                x_data=x_data[iisample],
                x_data_tokenized=None,
                **sample_kwargs,
                )
        else:
            x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac)
            x_data_tokenized=x_data_tokenized.to(torch.float)

            # + for debug:
            if debug_on:
                print("x_data_tokenized dim: ", x_data_tokenized.shape)
                print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                print("test: ", x_data_tokenized is not None)

            result=model.sample (
                x_data=None,
                x_data_tokenized=x_data_tokenized,
                **sample_kwargs,
                )

        result=torch.round(result*ynormfac)

        # + for debug
        print('result dim: ', result.shape)

        plt.figure()
        plt.plot (
            result[0,0,:].cpu().detach().numpy(),
            label= f'Predicted'
        )
        plt.legend()
        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        if IF_showfig==1:
            plt.show ()
        else:
            plt.savefig(outname, dpi=200)
            plt.close()

        # decode the predicted tokens back into text
        to_rev=result[:,0,:]
        to_rev=to_rev.long().cpu().detach().numpy()
        print (to_rev.shape)
        y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
        y_data_reversed=[seq.upper().strip().replace(" ", "") for seq in y_data_reversed]

        ### reverse second structure input....
        if X_cond is not None:
            X_cond=torch.round(X_cond*Xnormfac)
            to_rev=X_cond[:,:]
            to_rev=to_rev.long().cpu().detach().numpy()
            print (to_rev.shape)
            X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
            # clean every decoded entry; the original bounded this loop by
            # len(y_data_reversed), i.e. by the length of a different list
            X_data_reversed=[seq.upper().strip().replace(" ", "") for seq in X_data_reversed]
        if x_data is not None:
            X_data_reversed=x_data #is already in sequence format..

        print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed)

        # + for debug
        print("================================================")
        print("foldproteins: ", foldproteins)

        if not foldproteins:
            pdb_file=None
        else:
            # The original also built an `out_nam` from X_cond[iisample] /
            # x_data[iisample] here; it was overwritten before any use, and the
            # X_cond lookup raised IndexError for iisample >= 1 because X_cond
            # holds a single row. That dead computation is dropped.
            tempname='temp'
            pdb_file=foldandsavePDB (
                sequence=y_data_reversed[0],
                filename_out=tempname,
                num_cycle=num_cycle,
                flag=flag,
                prefix=sample_dir,
                )

            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
            write_fasta (y_data_reversed[0], out_nam_fasta)

            out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
            shutil.copy (pdb_file, out_nam) #source, dest
            pdb_file=out_nam

            print (f"Properly named PDB file produced: {pdb_file}")

            if IF_showfig==1:
                # NOTE(review): show_sidechains / show_mainchains / color are
                # not parameters here; they must exist at module level.
                view=show_pdb(
                    pdb_file=pdb_file,
                    flag=flag,
                    show_sidechains=show_sidechains,
                    show_mainchains=show_mainchains,
                    color=color
                )
                view.show()

    return pdb_file
lenn_val=len(x_data) if tokenizer_X==None: # for ForcPath, # need to do Padding and Normalization # and then put into tokenized data channel x_data_tokenized=[] for ii in range(lenn_val): x_data_one_line=pad_a_np_arr(x_data[ii], 0.0, max_length) x_data_tokenized.append(x_data_one_line) x_data_tokenized=np.array(x_data_tokenized) x_data_tokenized=torch.from_numpy(x_data_tokenized/Xnormfac) else: # leave for SecStr case: TBA pass # print (x_data) # ++ for result_mask based on input: x_data or x_data_tokenized # ++: for model B: using mask from the input # extract the mask/seq_len from input if possible if tokenizer_X!=None: # for SecStr+ModelB result_mask = read_mask_from_input( tokenized_data=None, mask_value=None, seq_data=x_data, # x_data[iisample], max_seq_length=max_length, ) else: # for ForcPath+ModelB result_mask = read_mask_from_input( tokenized_data=x_data_tokenized, # None, mask_value=0, # None, seq_data=None, # x_data[iisample], max_seq_length=None, # max_length, ) print ("Input contents:") print ("cond_img condition: x_data=\n", x_data) print ("Text condition: X_cond=\n", X_cond) # store the results pdb_file_list=[] fasta_file_list=[] # loop over cond_scales for idx_cond, this_cond_scale in enumerate(cond_scales): print(f"Working on cond_scale {str(this_cond_scale)}") # do sampling # ----------------------------------------------------------------- # for below, two branches are all for cond_img, not for text_cond if tokenizer_X!=None: # for SecStr+ModelB, not test here result_embedding=model.sample ( x=X_cond, stop_at_unet_number=train_unet_number , cond_scale=this_cond_scale, # cond_scales , x_data=x_data, # x_data[iisample], # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim # ++ x_data_tokenized=None, skip_steps=skip_steps, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = init_images,device=device, # ++++++++++++++++++++++++++ 
tokenizer_X=tokenizer_X, Xnormfac=Xnormfac, max_length=max_length, ) else: # for ForcPath+ModelB: # for model.sample() here using x_data_tokenized channel x_data_tokenized=x_data_tokenized.to(torch.float) # shape [batch, max_seq_len] # leave channel expansion for the self.sample() to handle # + for debug: if CKeys['Debug_TrainerPack']==3: print("x_data_tokenized dim: ", x_data_tokenized.shape) print("x_data_tokenized dtype: ", x_data_tokenized.dtype) print("test x_data_tokenized!=None: ", x_data_tokenized!=None) result_embedding=model.sample ( x=X_cond, stop_at_unet_number=train_unet_number , cond_scale=this_cond_scale, # cond_scales , x_data=None, # ++ x_data_tokenized=x_data_tokenized, # skip_steps=skip_steps, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = init_images,device=device, # ++++++++++++++++++++++++++ tokenizer_X=tokenizer_X, Xnormfac=Xnormfac, max_length=max_length, ) # handle the results: from embedding into AA # ++ for pLM if image_channels==33: # pass result_tokens,result_logits = convert_into_tokens_using_prob( result_embedding, pLM_Model_Name, ) else: # full record # result_embedding as image.dim: [batch, channels, seq_len] # result_tokens.dim: [batch, seq_len] result_tokens,result_logits = convert_into_tokens( pLM_Model, result_embedding, pLM_Model_Name, ) # +++++++++++++++++++++++++++++++++ result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len] # + for debug print('result dim: ', result.shape) # plot sequence token code: esm (33 tokens), for one batch fig=plt.figure() for ii in range(lenn_val): plt.plot ( result[ii,0,:].cpu().detach().numpy(), label= f'Predicted for Input#{str(ii)}' ) #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}') plt.legend() outname = sample_dir+ f"DenovoInputXs_CondScale_No{str(idx_cond)}_Val_{str(this_cond_scale)}_{e}_{steps}.jpg" #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}") if IF_showfig==1: plt.show () 
else: plt.savefig(outname, dpi=200) plt.close() # translate result into AA to_rev=result[:,0,:] # token (batch,seq_len) if CKeys['Debug_TrainerPack']==3: print("on foldable result: ", to_rev[0]) print("on result_logits: ", result_logits[0]) print("on mask: ", result_mask[0]) a = decode_one_ems_token_rec_for_folding_with_mask( to_rev[0], result_logits[0], pLM_alphabet, pLM_Model, result_mask[0], ) print('One resu: ', a) y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( to_rev, result_logits, pLM_alphabet, pLM_Model, result_mask, ) if CKeys['Debug_TrainerPack']==3: print("on y_data_reversed[0]: ", y_data_reversed[0]) ### reverse second structure input.... if X_cond != None: X_cond=torch.round(X_cond*Xnormfac) to_rev=X_cond[:,:] to_rev=to_rev.long().cpu().detach().numpy() print (to_rev.shape) X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) for iii in range (len(y_data_reversed)): X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") if x_data !=None: # work for second structure input.... # work for ForcPath input... X_data_reversed=x_data #is already in sequence fromat.. 
# sections for each one result for iisample in range(lenn_val): print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample]) out_nam_fasta=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.fasta' write_fasta (y_data_reversed[iisample], out_nam_fasta) fasta_file_list.append(out_nam_fasta) # + for debug print("================================================") print("foldproteins: ", foldproteins) if not foldproteins: pdb_file=None else: if X_cond != None: # not maintained xbc=X_cond[iisample,:].cpu().detach().numpy() out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})+f'_{flag}_{steps}' if x_data !=None: pass # #xbc=x_data[iisample] # # ---------------------------------- # # this one can be too long for a name # out_nam=x_data[iisample] # # ++++++++++++++++++++++++++++++++++ # # # out_nam=iisample tempname='temp' pdb_file, fasta_file=foldandsavePDB_pdb_fasta ( sequence=y_data_reversed[iisample], filename_out=tempname, num_cycle=num_cycle, flag=flag, # +++++++++++++++++++ # prefix=prefix, prefix=sample_dir, ) # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' # ------------------------------------------- # this one can be too long for a name # However, the input X is recorded in the code # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb' # +++++++++++++++++++++++++++++++++++++++++++ out_nam=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.pdb' # out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' # print('Debug 1: out: ', out_nam) # print('Debug 2: in: ', pdb_file) shutil.copy (pdb_file, out_nam) #source, dest # shutil.copy (fasta_file, out_nam_fasta) # cmd_line = 'cp ' + pdb_file + ' ' + out_nam # print(cmd_line) # os.popen(cmd_line) # print('Debug 3') # clean the slade to avoid 
# ++
def sample_sequence_pLM_ModelB_For_ForcPath_Predictor (
    model,
    X=None, # this is the target conventionally when using text embd
    flag=0,
    cond_scales=[1.], # list of cond scales; only read here, never mutated
    foldproteins=False,
    X_string=None,
    x_data=None,
    skip_steps=0,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,
    # ++++++++++++++++++++++++
    ynormfac=1,
    train_unet_number=1,
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    prefix=None,
    tokenizer_y=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only after foldproteins=True
    # ++
    pLM_Model=None,
    pLM_Model_Name=None,
    image_channels=None,
    pLM_alphabet=None,
    # ++
    esm_layer=None,
):
    """Predict a per-residue quantity (ForcPath) for a list of AA sequences.

    For every entry of ``cond_scales`` the model is sampled with ``x_data`` on
    the image-conditioning channel. The result is averaged over the channel
    axis, multiplied by the mask read from ``x_data``, scaled by ``ynormfac``
    and cut per input to ``len(sequence) + 1`` entries.

    Args:
        model: Object exposing ``sample(...)``; called once per cond scale.
        X: Optional text-conditioning values, converted with
            ``torch.Tensor(X)``. Ignored when ``X_string`` is given.
        cond_scales: Iterable of conditioning scales to sample with.
        X_string: Optional strings; tokenized with ``tokenizer_X``, padded to
            ``max_length`` and divided by ``Xnormfac``.
        x_data: Required. Sequence of inputs (AA strings per the original
            header comment); ``len()`` of each entry sets the length of its
            prediction.
        skip_steps, inpaint_images, inpaint_masks, inpaint_resample_times,
        init_images, tokenizer_X, Xnormfac, max_length, pLM_Model,
        pLM_alphabet, esm_layer, pLM_Model_Name: Forwarded to ``model.sample``
            (``max_length`` is also used for the mask).
        train_unet_number: Forwarded as ``stop_at_unet_number``.
        ynormfac: Factor the masked prediction is multiplied by.
        CKeys: Mapping read at ``'Debug_TrainerPack'``; the value 3 enables
            extra debug printing. ``None`` disables debug output.
        flag, foldproteins, num_cycle, prefix, tokenizer_y, sample_dir, steps,
        e, IF_showfig, image_channels: Accepted for interface compatibility;
            not read by this function.

    Returns:
        ``(resu_prediction, seq_len_list)``: a dict mapping
        ``str(cond_scale)`` to the list of per-input predictions, and the list
        of input lengths.

    Raises:
        ValueError: If ``x_data`` is None. The mask, the truncation lengths
            and the return value all depend on it.
    """
    # x_data drives the mask and seq_len_list; without it the original code
    # failed late with an UnboundLocalError/TypeError.
    if x_data is None:
        raise ValueError("x_data (list of input sequences) is required")
    # CKeys defaults to None; treat that as "debug output off".
    debug_level = None if CKeys is None else CKeys['Debug_TrainerPack']

    # 1. decide input channel: text_cond or img_cond
    X_cond = None  # this is for text-conditioning
    if X_string is None and X is not None:  # only do if X provided
        print (f"Producing {len(X)} samples...from text conditioning X...")
        # shape of X: [[..],[..]]: double bracket
        X_cond = torch.Tensor(X).to(device)
    if X_string is not None:
        print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...")
        XX = tokenizer_X.texts_to_sequences(X_string)
        # NOTE(review): `sequence` is not imported in the part of the file
        # visible here; confirm it is in scope before using this branch.
        XX = sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post')
        XX = np.array(XX)
        X_cond = torch.from_numpy(XX).float()/Xnormfac
        print ('Tokenized and processed: ', X_cond)

    # this is the img_conditioning channel actually used for prediction
    print (f"Producing {len(x_data)} samples...from image conditingig x_data ...")
    seq_len_list = [len(this_AA) for this_AA in x_data]

    print ("Input contents:")
    print ("cond_img condition: x_data=\n", x_data)
    print ("Text condition: X_cond=\n", X_cond)

    # 2. perform sampling, looping over cond_scales
    resu_prediction = {}
    for this_cond_scale in cond_scales:
        print(f"Working on cond_scale {str(this_cond_scale)}")
        # The translation from sequence to tokens happens inside
        # model.sample via the x_data channel, hence the pLM arguments.
        result_embedding = model.sample (
            x=X_cond,
            stop_at_unet_number=train_unet_number,
            cond_scale=this_cond_scale,
            x_data=x_data,
            x_data_tokenized=None,
            skip_steps=skip_steps,
            inpaint_images=inpaint_images,
            inpaint_masks=inpaint_masks,
            inpaint_resample_times=inpaint_resample_times,
            init_images=init_images,
            device=device,
            tokenizer_X=tokenizer_X,
            Xnormfac=Xnormfac,
            max_length=max_length,
            # ++ for esm
            pLM_Model=pLM_Model,
            pLM_alphabet=pLM_alphabet,
            esm_layer=esm_layer,
            pLM_Model_Name=pLM_Model_Name,
        )

        # 3. translate prediction: average across channels -> (batch, seq_len)
        result_embedding = torch.mean(result_embedding, 1)
        # mask read from the input; per the original note it looks like
        # 0,1,1,...,1,0,0 so the 0th component is zeroed as well
        result_mask = read_mask_from_input(
            tokenized_data=None,
            mask_value=0.0,
            seq_data=x_data,
            max_seq_length=max_length,
        )
        # keep masked-in positions, zero everything else
        y_data_reversed = (result_embedding.cpu()*result_mask)*ynormfac

        # 4. translate the results into a list, one entry per input
        prediction_list = [
            y_data_reversed[ii, :this_len+1]
            for ii, this_len in enumerate(seq_len_list)
        ]
        if debug_level == 3:
            print("check prediction dim:")
            print("model output: ", y_data_reversed[0])
            print("prediction output: ", prediction_list[0])

        # store the results
        resu_prediction[str(this_cond_scale)] = prediction_list

    return resu_prediction, seq_len_list
# ++
# for ProteinDesigner
def sample_loop_omegafold_pLM_ModelB (
    model,
    train_loader,
    cond_scales=[7.5], # list of cond scales - each sampled; only read here
    num_samples=2, # how many samples produced every time tested
    timesteps=100, # not used
    flag=0,
    foldproteins=False,
    use_text_embedd=True,
    skip_steps=0,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=1,
    prefix=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only after foldproteins=True
    # ++
    pLM_Model=None,
    pLM_Model_Name=None,
    image_channels=None,
    pLM_alphabet=None,
    # ++
    num_cycle=16,
):
    """Sample Model B on batches of a data loader and record the results.

    For every batch and every entry of ``cond_scales`` the first
    ``num_samples`` items of the batch are sampled, converted to tokens and
    decoded into sequences. For each of them the function plots predicted vs.
    ground-truth tokens, prints the recovery ratio, and writes the input
    condition (``.txt``), the prediction (``_predict.fasta``) and the decoded
    ground truth (``_recGT.fasta``). When ``foldproteins`` is true the
    prediction is also folded and the PDB/FASTA files are copied to
    descriptive names.

    Args:
        model: Object exposing ``sample(...)``.
        train_loader: Iterable of items where ``item[0]`` is the conditioning
            batch and ``item[1]`` the target batch (both support ``.to``).
        cond_scales: Iterable of conditioning scales to sample with.
        num_samples: Upper bound on samples taken from each batch.
        flag: Forwarded to the folding helper and to the PDB viewer.
        foldproteins: Fold every decoded sequence when true.
        use_text_embedd: When true the picked batch is passed as ``x``;
            otherwise as ``x_data_tokenized`` with ``x=None``.
        skip_steps: Forwarded to ``model.sample``.
        train_unet_number: Forwarded as ``stop_at_unet_number``.
        prefix: Forwarded to the folding helper as ``prefix``.
        Xnormfac: Factor the conditioning batch is multiplied by before it is
            written out.
        tokenizer_X: When set, the de-normalised conditioning is rounded and
            turned into text with ``sequences_to_texts``.
        CKeys: Mapping read at ``'Debug_TrainerPack'``; values 1 and 3 enable
            extra debug printing. ``None`` disables debug output.
        sample_dir: String prefix of every output file name; it is
            concatenated directly, so it must end with a path separator.
        steps, e: Only used inside output file names.
        IF_showfig: When equal to 1 plots are shown instead of saved, and
            folded structures are displayed.
        pLM_Model, pLM_Model_Name, pLM_alphabet: Forwarded to the token
            conversion and decoding helpers.
        image_channels: The value 33 selects ``convert_into_tokens_using_prob``;
            anything else selects ``convert_into_tokens``.
        num_cycle: Forwarded to the folding helper (previously fixed at 16).
        timesteps, ynormfac, tokenizer_y: Accepted for interface
            compatibility; not read by this function.

    Returns:
        None. Results are written to files under ``sample_dir``.
    """
    # CKeys defaults to None; treat that as "debug output off".
    debug_level = None if CKeys is None else CKeys['Debug_TrainerPack']

    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    for idx, item in enumerate(train_loader):
        X_train_batch = item[0].to(device)
        y_train_batch = item[1].to(device)

        GT = y_train_batch.cpu().detach().unsqueeze(1)  # (batch, 1, seq_len)

        # Clamp per batch in a local. Overwriting num_samples itself would
        # keep the reduced count for every later (larger) batch.
        batch_size = y_train_batch.shape[0]
        if num_samples > batch_size:
            print("Warning: sampling # > len(mini_batch)")
        n_pick = min(num_samples, batch_size)
        print (f"Producing {n_pick} samples...")
        X_train_batch_picked = X_train_batch[:n_pick,:]

        for this_cond_scale in cond_scales:
            if use_text_embedd:
                result_embedding = model.sample (
                    x=X_train_batch_picked,
                    stop_at_unet_number=train_unet_number,
                    cond_scale=this_cond_scale,
                    device=device,
                    skip_steps=skip_steps
                )
            else:
                # dim=(batch, seq_len); channels are extended inside .sample()
                result_embedding = model.sample (
                    x=None,
                    x_data_tokenized=X_train_batch_picked,
                    stop_at_unet_number=train_unet_number,
                    cond_scale=this_cond_scale,
                    device=device,
                    skip_steps=skip_steps
                )

            # ++ for pLM: embedding -> tokens
            if image_channels == 33:
                result_tokens, result_logits = convert_into_tokens_using_prob(
                    result_embedding,
                    pLM_Model_Name,
                )
            else:
                result_tokens, result_logits = convert_into_tokens(
                    pLM_Model,
                    result_embedding,
                    pLM_Model_Name,
                )
            result = result_tokens.unsqueeze(1)  # [batch, 1, seq_len]

            # from the dataloader only tokenized data is available for the mask
            result_mask = read_mask_from_input(
                tokenized_data=X_train_batch[:n_pick],
                mask_value=0,
                seq_data=None,
                max_seq_length=None,
            )
            to_rev = result[:,0,:]  # token (batch, seq_len)
            if debug_level == 3:
                print("on foldable result: ", to_rev[0])
                print("on result_logits: ", result_logits[0])
                print("on mask: ", result_mask[0])
                a = decode_one_ems_token_rec_for_folding_with_mask(
                    to_rev[0],
                    result_logits[0],
                    pLM_alphabet,
                    pLM_Model,
                    result_mask[0],
                )
                print('One resu: ', a)
            y_data_reversed = decode_many_ems_token_rec_for_folding_with_mask(
                to_rev,
                result_logits,
                pLM_alphabet,
                pLM_Model,
                result_mask,
            )
            if debug_level == 3:
                print("on y_data_reversed[0]: ", y_data_reversed[0])

            # reverse GT_y sequence: (batch, 1, seq_len) -> (batch, seq_len)
            GT_y_data_reversed = decode_many_ems_token_rec(
                GT[:,0,:],
                pLM_alphabet,
            )

            ### reverse general float input...
            to_rev = (X_train_batch[:,:]*Xnormfac).cpu().detach().numpy()
            # X_train_batch is assumed to be padded at both ends, so the 0th
            # padding entry is moved to the end by shifting left by one.
            n_embed = to_rev.shape[1]
            to_rev_1 = np.zeros(to_rev.shape)
            to_rev_1[:,0:n_embed-1] = to_rev[:,1:n_embed]
            if tokenizer_X is not None:
                # change into int-valued tokens before the text lookup
                to_rev_1 = np.round(to_rev_1)
                X_data_reversed = tokenizer_X.sequences_to_texts (to_rev_1)
                for iii in range (len(y_data_reversed)):
                    X_data_reversed[iii] = X_data_reversed[iii].upper().strip().replace(" ", "")
            else:
                X_data_reversed = to_rev_1.copy()
            if debug_level == 1:
                print("X_data_reversed: ", X_data_reversed)

            for samples in range (n_pick):
                print ("sample ", samples+1, "out of ", n_pick)
                plt.figure()
                plt.plot (
                    result[samples,0,:].cpu().detach().numpy(),
                    label='Predicted'
                )
                plt.plot (
                    GT[samples,0,:],
                    label='GT 0'
                )
                plt.legend()
                outname = sample_dir + f"Batch_{idx}_sample_{samples}_condscale-{str (this_cond_scale)}_{e}_{steps}.jpg"
                if IF_showfig == 1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                    plt.close ()

                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} \nor\n {X_data_reversed[samples]}, ")
                print (f"predicted sequence: {y_data_reversed[samples]}")
                print (f"Ground truth: {GT_y_data_reversed[samples]}")
                error = string_diff (y_data_reversed[samples], GT_y_data_reversed[samples])/len (GT_y_data_reversed[samples])
                print(f"Recovery ratio(Ref): {1.-error}")

                xbc = X_data_reversed[samples]
                if isinstance(xbc, str):
                    out_nam_content = xbc
                else:
                    out_nam_content = np.array2string(
                        xbc,
                        formatter={'float_kind': lambda v: "%.4f" % v},
                    )

                # common stem of every file written for this sample
                out_stem = f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (this_cond_scale)}_epo_{e}_step_{steps}'

                # 1. write out the input X in the dataloader
                out_nam_inX = f'{out_stem}.txt'
                with open(out_nam_inX, "w") as inX_file:
                    inX_file.write(out_nam_content)
                # 2. write out the predictions
                out_nam_OuY_PR = f'{out_stem}_predict.fasta'
                with open(out_nam_OuY_PR, "w") as ouY_fasta:
                    ouY_fasta.write(">Predicted\n")
                    ouY_fasta.write(y_data_reversed[samples])
                # 3. only for dataloader: write out the recovered ground truth
                out_nam_OuY_GT = f'{out_stem}_recGT.fasta'
                with open(out_nam_OuY_GT, "w") as ouY_fasta:
                    ouY_fasta.write(f">reconstructed GT, recoverabliblity: {1.-error}\n")
                    ouY_fasta.write(GT_y_data_reversed[samples])

                if not foldproteins:
                    continue

                # NOTE(review): the sibling de novo sampler passes
                # prefix=sample_dir here; confirm that prefix (default None)
                # is what the folding helper should receive.
                pdb_file, fasta_file = foldandsavePDB_pdb_fasta (
                    sequence=y_data_reversed[samples],
                    filename_out='temp',
                    num_cycle=num_cycle,
                    flag=flag,
                    prefix=prefix
                )
                out_nam = f'{out_stem}.pdb'
                out_nam_seq = f'{out_stem}.fasta'
                if debug_level == 1:
                    print("pdb_file: ", pdb_file)
                    print("out_nam: ", out_nam)
                print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                shutil.copy (pdb_file, out_nam)  # source, dest
                shutil.copy (fasta_file, out_nam_seq)
                # remove the temporaries so a later sample cannot pick them up
                os.remove (pdb_file)
                os.remove (fasta_file)
                pdb_file = out_nam

                print (f"Properly named PDB file produced: {pdb_file}")
                # report the input-condition file, not the PDB path
                print (f"input X for sampling stored: {out_nam_inX}")
                if IF_showfig == 1:
                    # NOTE(review): show_sidechains, show_mainchains and color
                    # are not parameters of this function; they are resolved
                    # from module scope. Confirm they are defined there.
                    view = show_pdb(
                        pdb_file=pdb_file,
                        flag=flag,
                        show_sidechains=show_sidechains,
                        show_mainchains=show_mainchains,
                        color=color
                    )
                    view.show()
# ++
# For ProteinPredictor
def sample_loop_omegafold_pLM_ModelB_Predictor (
    model,
    train_loader,
    cond_scales=[7.5], # list of cond scales - each sampled (never mutated here)
    num_samples=2, # how many samples are taken from each mini-batch
    timesteps=100, # not used
    flag=0,
    foldproteins=False,
    use_text_embedd=True,
    skip_steps=0,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=1,
    prefix=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True,
    # ++
    pLM_Model=None,
    pLM_Model_Name=None,
    image_channels=None,
    pLM_alphabet=None,
    # ++
    esm_layer=None,
):
    """Sample the predictor on every batch of a loader and score it with MSE.

    For each entry of ``cond_scales`` the whole ``train_loader`` is traversed.
    From every mini-batch the first ``min(num_samples, batch size)`` records
    are sampled with ``model.sample``, averaged over the channel axis (dim 1),
    masked with ``read_mask_from_input`` and compared with the ground truth.

    Args:
        model: object exposing ``sample(...)``.
        train_loader: iterable of ``(X, y)`` batches; both items must support
            ``.to(device)`` (the module-level ``device`` is used).
        cond_scales: conditioning scales to evaluate, one full pass each.
        num_samples: upper bound on the records sampled per mini-batch.
        use_text_embedd: if true the picked X is passed as ``x``; otherwise it
            is passed as ``x_data_tokenized`` with ``x=None``.
        skip_steps, train_unet_number: forwarded to ``model.sample``
            (``train_unet_number`` as ``stop_at_unet_number``).
        ynormfac: factor applied to both prediction and ground truth before
            they are stored in the returned dictionaries.
        pLM_Model, pLM_Model_Name, pLM_alphabet, esm_layer: forwarded
            unchanged to ``model.sample``.
        timesteps, flag, foldproteins, prefix, tokenizer_y, Xnormfac,
        tokenizer_X, CKeys, sample_dir, steps, e, IF_showfig, image_channels:
            accepted for signature compatibility; not read by this function.

    Returns:
        ``(mse_list, predictions, groundtruths)`` where ``mse_list`` holds one
        value per cond-scale (summed squared error divided by the sequence
        length and by the number of sampled records) and the two dictionaries
        map ``str(cond_scale)`` to a list of 1-D numpy arrays.  If the loader
        yields no records the MSE entry is ``nan``.
    """
    val_epoch_MSE_list = []
    resu_pred = {}
    resu_grou = {}

    for cond_scale in cond_scales:
        # accumulators for this cond_scale only
        val_epoch_MSE = 0.
        num_rec = 0
        this_prediction = []
        this_groundtruth = []

        for item in train_loader:
            X_train_batch = item[0].to(device)
            y_train_batch = item[1].to(device)

            # 1. adjust the number of samples to collect in this batch.
            # A per-batch local is used on purpose: overwriting num_samples
            # would let one short batch shrink the sample count of every
            # later batch and of every later cond_scale.
            batch_len = y_train_batch.shape[0]
            if num_samples > batch_len:
                print("Warning: sampling # > len(mini_batch)")
            n_pick = min(num_samples, batch_len)
            print (f"Producing {n_pick} samples...")
            X_train_batch_picked = X_train_batch[:n_pick, :]
            GT_picked = y_train_batch.cpu().detach()[:n_pick, :]

            # 2. any pLM embedding of the input is done inside model.sample
            # 3. sample; only the way X is handed over differs per mode
            sample_kwargs = dict(
                stop_at_unet_number=train_unet_number,
                cond_scale=cond_scale,
                device=device,
                skip_steps=skip_steps,
                pLM_Model_Name=pLM_Model_Name,
                pLM_Model=pLM_Model,
                pLM_alphabet=pLM_alphabet,
                esm_layer=esm_layer,
            )
            if use_text_embedd:
                result_embedding = model.sample(
                    x=X_train_batch_picked,
                    **sample_kwargs,
                )
            else:
                result_embedding = model.sample(
                    x=None,
                    x_data_tokenized=X_train_batch_picked,
                    **sample_kwargs,
                )

            # 4. average across channels (dim 1), then zero the positions
            # that the input mask marks as padding
            result_embedding = torch.mean(result_embedding, 1)
            result_mask = read_mask_from_input(
                tokenized_data=X_train_batch_picked,
                mask_value=0.0,
                seq_data=None,
                max_seq_length=None,
            )
            result = (result_embedding * result_mask).cpu()

            # 5. accumulate the loss
            with torch.no_grad():
                val_loss_MSE = criterion_MSE_sum(result, GT_picked)
            val_epoch_MSE += val_loss_MSE.item() / GT_picked.shape[1]
            num_rec += len(GT_picked)

            # 6. undo the normalisation and keep every record
            y_data_reversed = result * ynormfac
            GT_y_data_reversed = GT_picked * ynormfac
            for ibat in range(GT_picked.shape[0]):
                this_prediction.append(np.array(y_data_reversed[ibat, :].cpu()))
                this_groundtruth.append(np.array(GT_y_data_reversed[ibat, :].cpu()))

        # summarize this cond_scale; an empty loader gives nan, not a crash
        if num_rec > 0:
            TestSet_MSE = val_epoch_MSE / num_rec
        else:
            print("Warning: no record was sampled; MSE is undefined")
            TestSet_MSE = float('nan')

        resu_pred[str(cond_scale)] = this_prediction
        resu_grou[str(cond_scale)] = this_groundtruth
        val_epoch_MSE_list.append(TestSet_MSE)

    return val_epoch_MSE_list, resu_pred, resu_grou
# ++++++++++++++++++++++++++++++++++++++++++++++++
def sample_loop_omegafold_ModelB (
    model,
    train_loader,
    cond_scales=[7.5], # list of cond scales - each sampled (never mutated here)
    num_samples=2, # how many samples are produced from each mini-batch
    timesteps=100, # not used
    flag=0,
    foldproteins=False,
    use_text_embedd=True,
    skip_steps=0,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=1,
    prefix=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # show plots/structures instead of saving the plot
):
    """Sample sequences for every batch of a loader, plot them and fold them.

    For every mini-batch the first ``min(num_samples, batch size)`` conditions
    are sampled once per entry of ``cond_scales``.  Each sample is plotted
    against the ground truth, decoded to text with ``tokenizer_y`` and, when
    ``foldproteins`` is true, folded with ``foldandsavePDB``; the PDB file is
    copied to ``sample_dir`` and the decoded input is written next to it.

    Args:
        model: object exposing ``sample(...)``.
        train_loader: iterable of ``(X, y)`` batches (moved to the
            module-level ``device``).
        cond_scales: conditioning scales; each one is sampled.
        num_samples: upper bound on the samples produced per mini-batch.
        flag, prefix: forwarded to ``foldandsavePDB``.
        use_text_embedd: pass X as ``x`` (true) or as ``x_data_tokenized``.
        skip_steps, train_unet_number: forwarded to ``model.sample``.
        ynormfac, Xnormfac: factors applied (followed by rounding) before the
            token ids are decoded.
        tokenizer_y: object exposing ``sequences_to_texts`` for the output.
        tokenizer_X: same for the input; if ``None`` the rounded integer
            array is reported instead of text.
        CKeys: optional mapping; ``CKeys['Debug_TrainerPack'] == 1`` enables
            extra prints.  ``None`` disables them.
        sample_dir: string prefix of every file written here.
        steps, e: only used to build file names.
        IF_showfig: show the plot instead of saving it, and display the folded
            structure.
        timesteps: not used.

    Returns:
        None.
    """
    debug_on = CKeys is not None and CKeys['Debug_TrainerPack'] == 1

    def _clean(texts):
        # tokenizer output is space separated and lower case
        return [txt.upper().strip().replace(" ", "") for txt in texts]

    for idx, item in enumerate(train_loader):
        X_train_batch = item[0].to(device)
        y_train_batch = item[1].to(device)

        # per-batch sample count; num_samples itself is left untouched so a
        # short batch cannot shrink the count for the following batches
        batch_len = y_train_batch.shape[0]
        if num_samples > batch_len:
            print("Warning: sampling # > len(mini_batch)")
        n_pick = min(num_samples, batch_len)
        print (f"Producing {n_pick} samples...")
        X_train_batch_picked = X_train_batch[:n_pick, :]
        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)

        # Ground truth is rescaled exactly once per batch.  Doing it inside
        # the cond_scales loop multiplied it by ynormfac again for every
        # additional cond_scale.
        GT = torch.round(y_train_batch.cpu().detach().unsqueeze(1) * ynormfac)
        GT_y_data_reversed = _clean(
            tokenizer_y.sequences_to_texts(
                GT[:, 0, :].long().cpu().detach().numpy()
            )
        )

        # decoded input condition; does not depend on the cond_scale
        X_tokens = torch.round(X_train_batch[:, :] * Xnormfac)
        X_tokens = X_tokens.long().cpu().detach().numpy()
        if tokenizer_X is not None:
            X_data_reversed = _clean(tokenizer_X.sequences_to_texts(X_tokens))
        else:
            X_data_reversed = X_tokens

        for cond_scale in cond_scales:
            if use_text_embedd:
                result = model.sample(
                    x=X_train_batch_picked,
                    stop_at_unet_number=train_unet_number,
                    cond_scale=cond_scale,
                    device=device,
                    skip_steps=skip_steps,
                )
            else:
                result = model.sample(
                    x=None,
                    x_data_tokenized=X_train_batch_picked,
                    stop_at_unet_number=train_unet_number,
                    cond_scale=cond_scale,
                    device=device,
                    skip_steps=skip_steps,
                )
            result = torch.round(result * ynormfac)

            # decode all sampled sequences once instead of once per sample
            y_data_reversed = _clean(
                tokenizer_y.sequences_to_texts(
                    result[:, 0, :].long().cpu().detach().numpy()
                )
            )

            for samples in range(n_pick):
                print ("sample ", samples+1, "out of ", n_pick)

                fig = plt.figure()
                plt.plot(
                    result[samples, 0, :].cpu().detach().numpy(),
                    label='Predicted',
                )
                plt.plot(GT[samples, 0, :], label='GT 0')
                plt.legend()
                outname = sample_dir + f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scale)}_{e}_{steps}.jpg"
                if IF_showfig == 1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                plt.close(fig)

                if debug_on:
                    print("X_data_reversed: ", X_data_reversed)

                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples])
                print (f"Ground truth: {GT_y_data_reversed[samples]}")

                if not foldproteins:
                    continue

                pdb_file = foldandsavePDB(
                    sequence=y_data_reversed[samples],
                    filename_out='temp',
                    num_cycle=16,
                    flag=flag,
                    prefix=prefix,
                )
                # the condition itself can be too long for a file name, so
                # files are named by batch / sample / cond_scale / epoch / step
                stem = f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scale)}_epo_{e}_step_{steps}'
                out_nam = f'{stem}.pdb'
                out_nam_inX = f'{stem}.txt'
                if debug_on:
                    print("pdb_file: ", pdb_file)
                    print("out_nam: ", out_nam)

                print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                shutil.copy(pdb_file, out_nam) # source, dest
                with open(out_nam_inX, "w") as inX_file:
                    inX_file.write(f'{X_data_reversed[samples]}\n')

                pdb_file = out_nam
                print (f"Properly named PDB file produced: {pdb_file}")
                # report the text file that was actually written
                print (f"input X for sampling stored: {out_nam_inX}")

                if IF_showfig == 1:
                    # NOTE(review): show_sidechains, show_mainchains and color
                    # are not parameters of this function; they must exist as
                    # module-level names - confirm.
                    view = show_pdb(
                        pdb_file=pdb_file,
                        flag=flag,
                        show_sidechains=show_sidechains,
                        show_mainchains=show_mainchains,
                        color=color,
                    )
                    view.show()
def sample_loop_FromModelB (model,
                train_loader,
                cond_scales=[7.5], # list of cond scales - each sampled (never mutated here)
                num_samples=2, # how many samples produced every time tested
                timesteps=100, # not used
                flag=0,foldproteins=False,
                use_text_embedd=True,skip_steps=0,
                # +++++++++++++++++++
                train_unet_number=1,
                ynormfac=1,
                prefix=None,
                tokenizer_y=None,
                Xnormfac=1,
                tokenizer_X=None,
               ):
    """Sample, plot and optionally fold sequences for batches of a loader.

    Every mini-batch is sampled in full once per entry of ``cond_scales``;
    the first ``num_samples`` results are plotted against the ground truth,
    saved as ``<prefix>sample-...jpg``, decoded to text and, when
    ``foldproteins`` is true, folded with ``foldandsavePDB`` and copied to
    ``<prefix><decoded input>.pdb``.

    Args:
        model: object exposing ``sample(...)``.
        train_loader: iterable of ``(X, y)`` batches (moved to the
            module-level ``device``).
        cond_scales: conditioning scales; each one is sampled.
        num_samples: number of samples handled per batch; clamped to the
            batch size (the clamped value is kept for later batches).
        flag: used in the plot file names and forwarded to the folding and
            viewing helpers.
        use_text_embedd: pass X as ``x`` (true) or as ``x_data_tokenized``.
        skip_steps, train_unet_number: forwarded to ``model.sample``.
        ynormfac, Xnormfac: factors applied (followed by rounding) before the
            token ids are decoded.
        prefix: string prefix of every file written here.
        tokenizer_y, tokenizer_X: objects exposing ``sequences_to_texts``.
        timesteps: not used.

    Returns:
        None.
    """
    def _clean(texts):
        # tokenizer output is space separated and lower case
        return [txt.upper().strip().replace(" ", "") for txt in texts]

    steps = 0
    e = flag
    for item in train_loader:
        X_train_batch = item[0].to(device)
        y_train_batch = item[1].to(device)

        num_samples = min(num_samples, y_train_batch.shape[0])
        print (f"Producing {num_samples} samples...")
        print ('X_train_batch shape: ', X_train_batch.shape)

        # Ground truth is rescaled exactly once per batch.  Doing it inside
        # the cond_scales loop multiplied it by ynormfac again for every
        # additional cond_scale.
        GT = torch.round(y_train_batch.cpu().detach().unsqueeze(1) * ynormfac)
        GT_y_data_reversed = _clean(
            tokenizer_y.sequences_to_texts(
                GT[:, 0, :].long().cpu().detach().numpy()
            )
        )

        # decoded input condition; does not depend on the cond_scale
        X_tokens = torch.round(X_train_batch[:, :] * Xnormfac)
        X_data_reversed = _clean(
            tokenizer_X.sequences_to_texts(
                X_tokens.long().cpu().detach().numpy()
            )
        )

        for cond_scale in cond_scales:
            if use_text_embedd:
                result = model.sample(
                    x=X_train_batch,
                    stop_at_unet_number=train_unet_number,
                    cond_scale=cond_scale,
                    device=device,
                    skip_steps=skip_steps,
                )
            else:
                result = model.sample(
                    x=None,
                    x_data_tokenized=X_train_batch,
                    stop_at_unet_number=train_unet_number,
                    cond_scale=cond_scale,
                    device=device,
                    skip_steps=skip_steps,
                )
            result = torch.round(result * ynormfac)

            # decode all sampled sequences once instead of once per sample
            y_data_reversed = _clean(
                tokenizer_y.sequences_to_texts(
                    result[:, 0, :].long().cpu().detach().numpy()
                )
            )

            for samples in range(num_samples):
                print ("sample ", samples, "out of ", num_samples)

                # one figure per sample, closed afterwards, so curves of
                # earlier samples do not pile up in later images
                fig = plt.figure()
                plt.plot(
                    result[samples, 0, :].cpu().detach().numpy(),
                    label='Predicted',
                )
                plt.plot(GT[samples, 0, :], label='GT 0')
                plt.legend()
                outname = prefix + f"sample-{samples}_condscale-{str (cond_scale)}_{e}_{steps}.jpg"
                # the order, save then show, matters
                plt.savefig(outname, dpi=200)
                plt.show()
                plt.close(fig)

                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples])
                print (f"Ground truth: {GT_y_data_reversed[samples]}")

                if foldproteins:
                    pdb_file = foldandsavePDB(
                        sequence=y_data_reversed[samples],
                        filename_out='temp',
                        num_cycle=16,
                        flag=flag,
                        prefix=prefix,
                    )
                    out_nam = f'{prefix}{X_data_reversed[samples]}.pdb'
                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                    shutil.copy(pdb_file, out_nam) # source, dest
                    pdb_file = out_nam
                    print (f"Properly named PDB file produced: {pdb_file}")
                    # NOTE(review): show_sidechains, show_mainchains and color
                    # are not parameters of this function; they must exist as
                    # module-level names - confirm.
                    view = show_pdb(
                        pdb_file=pdb_file,
                        flag=flag,
                        show_sidechains=show_sidechains,
                        show_mainchains=show_mainchains,
                        color=color,
                    )
                    view.show()

                # NOTE(review): nesting of this counter and of the break below
                # was reconstructed from flattened source - confirm that steps
                # counts samples and that the break leaves the loader loop.
                steps = steps + 1

        if steps > num_samples:
            break


# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# train_loop tasks:
# 1. calculate loss for one batch
# 2. call sample loop
# 3. call sample sequence
# 4. print records and save model
# ===============================================================
# for ProteinDesigner_B
# 1. expanded for Probability case
# ++
cal_norm_prob = nn.Softmax(dim=2)
# Index of the last transformer layer of every ESM-2 checkpoint accepted by
# train_loop_Model_B; the key is also the loader name in esm.pretrained.
_MODEL_B_ESM_LAYERS = {
    'esm2_t33_650M_UR50D': 33, # dim: 1280
    'esm2_t36_3B_UR50D': 36, # dim: 2560
    'esm2_t30_150M_UR50D': 30, # dim: 640
    'esm2_t12_35M_UR50D': 12, # dim: 480
}


def _load_pLM_for_Model_B (pLM_Model_Name, device):
    """Return ``(pLM_Model, esm_alphabet, esm_layer)`` for a model name.

    The string ``'None'`` means "no language model" and yields three
    ``None`` values.  Any name outside ``_MODEL_B_ESM_LAYERS`` raises
    ``ValueError``.  A loaded model is put in eval mode and moved to
    ``device``.
    """
    if pLM_Model_Name == 'None':
        return None, None, None
    if pLM_Model_Name not in _MODEL_B_ESM_LAYERS:
        raise ValueError(
            f"pLM model is missing... (unsupported pLM_Model_Name: {pLM_Model_Name!r})"
        )
    pLM_Model, esm_alphabet = getattr(esm.pretrained, pLM_Model_Name)()
    pLM_Model.eval()
    pLM_Model.to(device)
    return pLM_Model, esm_alphabet, _MODEL_B_ESM_LAYERS[pLM_Model_Name]


def train_loop_Model_B (
    model,
    train_loader,
    test_loader,
    #
    optimizer=None,
    print_every=10,
    epochs= 300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss_every_steps=1000,
    #
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=[7.5], # list of cond scales - each sampled (never mutated here)
    num_samples=2, # how many samples produced every time tested
    foldproteins=False,
    cond_image=False, # use cond_images...
    # +++++++++++++++++++++++++++
    device=None,
    loss_list=None, # appended to in place; a fresh list is used when None
    epoch_list=None, # appended to in place; a fresh list is used when None
    train_hist_file=None,
    train_hist_file_full=None,
    prefix=None,
    Xnormfac=1.,
    ynormfac=1.,
    tokenizer_X=None,
    tokenizer_y=None,
    test_condition_list=[], # only read, never mutated
    max_length=1,
    CKeys=None,
    sample_steps=1,
    sample_dir=None,
    save_every_epoch=1,
    save_point_info_file=None,
    store_dir=None,
    # ++
    pLM_Model_Name=None,
    image_channels=None,
    print_error=False,
):
    """Train ProteinDesigner model B with periodic logging, sampling, saving.

    Per step the function computes the loss through ``trainer`` (if given) or
    through ``model`` + ``optimizer``.  Every ``print_loss_every_steps`` steps
    the running mean loss is appended to ``train_hist_file`` and plotted;
    every ``sample_steps`` steps the test set and ``test_condition_list`` are
    sampled.  Per epoch the mean loss is appended to ``train_hist_file_full``
    and, if ``save_model`` is set, checkpoints are written to ``store_dir``.

    Args:
        model: network; called as ``model(y, x=..., unet_number=...)`` in
            optimizer mode and always used for ``state_dict()`` and sampling.
        train_loader / test_loader: iterables of ``(X, y)`` batches.
        optimizer: required when ``trainer`` is ``None``.
        trainer: optional; called for the loss, then ``update`` and ``save``.
        print_every: a progress dot is printed every this many steps.
        epochs, start_ep, start_step: number of epochs to run and the offsets
            used for reporting and file names.
        train_unet_number, max_batch_size: forwarded to trainer / model.
        plot_unscaled: additionally plot the raw target before sampling.
        cond_scales, num_samples, foldproteins: forwarded to the samplers.
        cond_image: if true, condition on ``cond_images`` built from X and
            train on the pLM representation of y; otherwise pass X as ``x``.
        device: device the batches and the language model are moved to.
        loss_list, epoch_list: histories extended in place.  They default to
            new lists; the former ``[]`` defaults were shared between calls.
        prefix, Xnormfac, ynormfac, tokenizer_X, tokenizer_y, max_length:
            forwarded to the samplers.
        test_condition_list: conditions for de novo sampling; element
            ``[ii][0]`` is compared with the DSSP result when ``print_error``.
        CKeys: mapping read for ``'Debug_TrainerPack'`` and ``'SlientRun'``.
        sample_dir: string prefix for loss plots and sampler output.
        save_every_epoch, save_point_info_file, store_dir: checkpoint control.
        pLM_Model_Name: ``'None'`` or a key of ``_MODEL_B_ESM_LAYERS``.
        image_channels: channel count for the repeated condition; the value
            33 additionally maps the representation through ``lm_head`` and
            the module-level softmax ``cal_norm_prob``.
        print_error: report the DSSP mismatch of the de novo samples.

    Raises:
        ValueError: if neither ``trainer`` nor ``optimizer`` is given, or if
            ``pLM_Model_Name`` is not supported.

    Returns:
        None.
    """
    if trainer is None and optimizer is None:
        # used to be a print followed by a crash at the first training step
        raise ValueError("ERROR: If trainer not used, need to provide optimizer.")
    if trainer is not None:
        print ("Trainer provided... will be used")

    loss_list = [] if loss_list is None else loss_list
    epoch_list = [] if epoch_list is None else epoch_list

    steps = start_step
    loss_total = 0
    # stays None until the first running-loss report has been produced
    norm_loss = None

    # ++ for pLM; alphabet and layer are defined (as None) without a pLM so
    # the samplers can always be called
    pLM_Model, esm_alphabet, esm_layer = _load_pLM_for_Model_B(
        pLM_Model_Name, device,
    )

    for e in range(1, epochs+1):
        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        # TRAINING
        train_epoch_loss = 0
        model.train()
        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item in train_loader:
            steps += 1
            X_train_batch = item[0].to(device)
            y_train_batch = item[1].to(device)

            if CKeys["Debug_TrainerPack"]==1:
                print("Initial unload the dataloader items: ...")
                print("X_train_batch.dim: ", X_train_batch.shape)
                print("y_train_batch.dim: ", y_train_batch.shape)

            # ---------------------------------------------------------
            # prepare for model.forward() to calculate loss
            # ---------------------------------------------------------
            if pLM_Model_Name=='None':
                # just use the encoded sequence
                y_train_batch_in = y_train_batch.unsqueeze(1)
                X_train_batch_in = X_train_batch.unsqueeze(1)
            else:
                # assume ESM models: project y into the embedding space
                with torch.no_grad():
                    results = pLM_Model(
                        y_train_batch,
                        repr_layers=[esm_layer],
                        return_contacts=False,
                    )
                # (batch, seq_len, channels)
                y_train_batch_in = results["representations"][esm_layer]
                # ++ for Probability case
                if image_channels==33:
                    with torch.no_grad():
                        # logits: (batch, seq_len, 33)
                        y_train_batch_in = pLM_Model.lm_head(y_train_batch_in)
                        # normalize to get (0,1) probability
                        y_train_batch_in = cal_norm_prob(y_train_batch_in)
                # switch the dimension -> (batch, channel, seq_len)
                y_train_batch_in = rearrange(
                    y_train_batch_in,
                    'b l c -> b c l'
                )
                X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1)

            if CKeys["Debug_TrainerPack"]==1:
                print("After pLM model, the shape of X and y for training:")
                print("X_train_batch_in.dim: ", X_train_batch_in.shape)
                print("y_train_batch_in.dim: ", y_train_batch_in.shape)

            if trainer is not None:
                # max_batch_size lets the trainer split the batch into chunks
                # and accumulate gradients, so it all fits in memory
                if cond_image:
                    loss = trainer(
                        y_train_batch_in, # true image
                        x=None, # tokenized text
                        cond_images=X_train_batch_in, # cond_image
                        unet_number=train_unet_number,
                        max_batch_size=max_batch_size,
                    )
                else:
                    loss = trainer(
                        y_train_batch.unsqueeze(1), # true image (batch, channels, seq_len)
                        x=X_train_batch, # tokenized text
                        unet_number=train_unet_number,
                        max_batch_size=max_batch_size,
                    )
                trainer.update(unet_number = train_unet_number)
            else:
                optimizer.zero_grad()
                if cond_image:
                    loss = model(
                        y_train_batch_in,
                        x=None,
                        cond_images=X_train_batch_in,
                        unet_number=train_unet_number,
                    )
                else:
                    loss = model(
                        y_train_batch.unsqueeze(1),
                        x=X_train_batch,
                        unet_number=train_unet_number,
                    )
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_value = loss.item()
            loss_total = loss_total + loss_value
            train_epoch_loss = train_epoch_loss + loss_value

            if steps % print_every == 0:
                # for progress bar
                print(".", end="")

            # ---------------- running-loss report ----------------
            if steps % print_loss_every_steps == 0:
                if CKeys['Debug_TrainerPack']==2:
                    print('I am here')
                    print("Here is steps: ", steps)

                norm_loss = loss_total/print_loss_every_steps
                print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}")
                # add a line to the hist file
                add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n'
                with open(train_hist_file,'a') as f:
                    f.write(add_line)

                loss_list.append(norm_loss)
                loss_total = 0
                epoch_list.append(e+start_ep)

                fig = plt.figure()
                plt.plot(epoch_list, loss_list, label='Loss')
                plt.legend()
                outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg"
                # the order, save then show, matters
                if CKeys['SlientRun']==1:
                    plt.savefig(outname, dpi=200)
                else:
                    plt.show()
                plt.close(fig)

            # ---------------- periodic sampling ----------------
            if steps % sample_steps == 0:
                if CKeys['Debug_TrainerPack']==2:
                    print('I am here')
                    print("Here is steps: ", steps)

                if plot_unscaled:
                    # test before scaling...
                    plt.plot(
                        y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),
                        label= 'Unscaled GT'
                    )
                    plt.legend()
                    plt.show()

                print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                print ("I. SAMPLING IN TEST SET: ")
                print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                print (f"Producing {num_samples} samples...")

                # conditioning on cond_images replaces the text embedding
                use_text_embedd = not cond_image

                sample_loop_omegafold_pLM_ModelB (
                    model,
                    test_loader,
                    cond_scales=cond_scales,
                    num_samples=num_samples,
                    timesteps=64,
                    flag=steps,
                    foldproteins=foldproteins,
                    use_text_embedd=use_text_embedd,
                    train_unet_number=train_unet_number,
                    ynormfac=ynormfac,
                    prefix=prefix,
                    tokenizer_y=tokenizer_y,
                    Xnormfac=Xnormfac,
                    tokenizer_X=tokenizer_X,
                    CKeys=CKeys,
                    sample_dir=sample_dir,
                    steps=steps,
                    e=e+start_ep,
                    IF_showfig= CKeys['SlientRun']!=1,
                    pLM_Model=pLM_Model,
                    pLM_Model_Name=pLM_Model_Name,
                    image_channels=image_channels,
                    pLM_alphabet=esm_alphabet,
                )

                print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                print ("II. SAMPLING FOR DE NOVO:")
                print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")

                DeNovoSam_pdbs, fasta_file_list = \
                sample_sequence_omegafold_pLM_ModelB (
                    model,
                    x_data=test_condition_list,
                    flag=steps,
                    cond_scales=1.,
                    foldproteins=foldproteins,
                    ynormfac=ynormfac,
                    train_unet_number=train_unet_number,
                    tokenizer_X=tokenizer_X,
                    Xnormfac=Xnormfac,
                    max_length=max_length,
                    prefix=prefix,
                    tokenizer_y=tokenizer_y,
                    CKeys=CKeys,
                    sample_dir=sample_dir,
                    steps=steps,
                    e=e+start_ep,
                    IF_showfig= CKeys['SlientRun']!=1,
                    pLM_Model=pLM_Model,
                    pLM_Model_Name=pLM_Model_Name,
                    image_channels=image_channels,
                    pLM_alphabet=esm_alphabet,
                )

                if print_error and len(DeNovoSam_pdbs)>0:
                    print("Calculate SecStr and design error:")
                    for ii, condition in enumerate(test_condition_list):
                        seq = condition[0]
                        DSSPresult,_,sequence_res = get_DSSP_result(DeNovoSam_pdbs[ii])
                        print (f"INPUT: {seq}\nRESULT: {DSSPresult}\nAA sequence: {sequence_res}")
                        error = string_diff(DSSPresult, seq)/len(seq)
                        print ("Error: ", error)

        # every epoch:
        norm_loss_over_e = train_epoch_loss/len(train_loader)
        print("\nnorm_loss over 1 epoch: ", norm_loss_over_e)
        # write this into "train_hist_file_full"
        add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
        with open(train_hist_file_full,'a') as f:
            f.write(add_line)

        # save model every this epoches (never after the very first epoch)
        if save_model and (e+start_ep) % save_every_epoch==0 and e>1:
            if trainer is not None:
                # in optimizer-only mode there is no trainer state to save
                fname = f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt"
                trainer.save(fname)
                print (f"Model saved: ", fname)
            fname = f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt"
            torch.save(model.state_dict(), fname)
            print (f"Statedict model saved: ", fname)

            # add a saving point file; if no running loss has been reported
            # yet, fall back to the mean loss of this epoch
            saved_loss = norm_loss if norm_loss is not None else norm_loss_over_e
            top_line = 'epoch,steps,norm_loss'+'\n'
            add_line = str(e+start_ep)+','+str(steps)+','+str(saved_loss)+'\n'
            with open(save_point_info_file, "w") as f:
                f.write(top_line)
                f.write(add_line)

        print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------")
# add some # +++++++++++++++++++++++++++ device=None, loss_list=[], epoch_list=[], train_hist_file=None, train_hist_file_full=None, prefix=None, Xnormfac=1., ynormfac=1., tokenizer_X=None, tokenizer_y=None, test_condition_list=[], max_length=1, CKeys=None, sample_steps=1, sample_dir=None, save_every_epoch=1, save_point_info_file=None, store_dir=None, # ++ pLM_Model_Name=None, image_channels=None, print_error=False, # ++ train_hist_file_on_testset=None, ): if not exists (trainer): if not exists (optimizer): print ("ERROR: If trainer not used, need to provide optimizer.") if exists (trainer): print ("Trainer provided... will be used") # steps=start_step+1 # # + # added_steps=0+1 # steps=start_step added_steps=0 loss_total=0 # ++ for pLM # # -- # if pLM_Model_Name=='trivial': # pLM_Model=None # elif pLM_Model_Name=='esm2_t33_650M_UR50D': # # dim: 1280 # esm_layer=33 # pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() # len_toks=len(esm_alphabet.all_toks) # pLM_Model.eval() # pLM_Model. to(device) # elif pLM_Model_Name=='esm2_t36_3B_UR50D': # # dim: 2560 # esm_layer=36 # pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() # len_toks=len(esm_alphabet.all_toks) # pLM_Model.eval() # pLM_Model. to(device) # elif pLM_Model_Name=='esm2_t30_150M_UR50D': # # dim: 640 # esm_layer=30 # pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() # len_toks=len(esm_alphabet.all_toks) # pLM_Model.eval() # pLM_Model. to(device) # elif pLM_Model_Name=='esm2_t12_35M_UR50D': # # dim: 480 # esm_layer=12 # pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() # len_toks=len(esm_alphabet.all_toks) # pLM_Model.eval() # pLM_Model. 
to(device) # else: # print("pLM model is missing...") # ++ pLM_Model, esm_alphabet, \ esm_layer, len_toks = load_in_pLM( pLM_Model_Name, device, ) for e in range(1, epochs+1): # start = time.time() torch.cuda.empty_cache() print ("######################################################################################") start = time.time() print ("NOW: Training epoch: ", e+start_ep) # TRAINING train_epoch_loss = 0 model.train() print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") for item in train_loader: steps += 1 added_steps += 1 X_train_batch= item[0].to(device) y_train_batch= item[1].to(device) # project y_ into embedding space if CKeys["Debug_TrainerPack"]==1: print("Initial unload the dataloader items: ...") print("X_train_batch.dim: ", X_train_batch.shape) print("y_train_batch.dim: ", y_train_batch.shape) # --------------------------------------------------------- # prepare for model.forward() to calculate loss # --------------------------------------------------------- # # -- # if pLM_Model_Name=='None': # # just use the encoded sequence # y_train_batch_in = y_train_batch.unsqueeze(1) # X_train_batch_in = X_train_batch.unsqueeze(1) # # pass # elif pLM_Model_Name=='esm2_t33_650M_UR50D': # with torch.no_grad(): # results = pLM_Model( # y_train_batch, # repr_layers=[33], # return_contacts=False, # ) # y_train_batch_in = results["representations"][33] # (batch, seq_len, channels) # y_train_batch_in = rearrange( # y_train_batch_in, # 'b l c -> b c l' # ) # X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) # else: # print(f"Required pLM name is not defined!!") # ++ if pLM_Model_Name=='trivial': # just use the encoded sequence y_train_batch_in = y_train_batch.unsqueeze(1) X_train_batch_in = X_train_batch.unsqueeze(1) # pass else: # assume ESM models # -- # # for ProteinDesigner # with torch.no_grad(): # results = pLM_Model( # y_train_batch, # repr_layers=[esm_layer], # return_contacts=False, # ) # 
y_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels) # y_train_batch_in = rearrange( # y_train_batch_in, # 'b l c -> b c l' # ) # X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) # # ++ # for ProteinPredictor with torch.no_grad(): results = pLM_Model( X_train_batch, repr_layers=[esm_layer], return_contacts=False, ) X_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels) X_train_batch_in = rearrange( X_train_batch_in, 'b l c -> b c l' ) y_train_batch_in = y_train_batch.unsqueeze(1).repeat(1,image_channels,1) # + for debug if CKeys["Debug_TrainerPack"]==1: print("After pLM model, the shape of X and y for training:") print("X_train_batch_in.dim: ", X_train_batch_in.shape) print("y_train_batch_in.dim: ", y_train_batch_in.shape) # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ if exists (trainer): if cond_image==False: loss = trainer( y_train_batch.unsqueeze(1) , # true image (batch, channels, seq_len) x=X_train_batch, # tokenized text (batch, ) unet_number=train_unet_number, max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory ) # # ---------------------------------------------------------- # if cond_image==True: # loss = trainer( # y_train_batch.unsqueeze(1) , # true image # x=None, # tokenized text # cond_images=X_train_batch.unsqueeze(1), # cond_image # unet_number=train_unet_number, # max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory # ) # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ if cond_image==True: loss = trainer( y_train_batch_in, # true image x=None, # tokenized text cond_images=X_train_batch_in, # cond_image unet_number=train_unet_number, max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory 
) trainer.update(unet_number = train_unet_number) else: optimizer.zero_grad() if cond_image==False: loss=model ( y_train_batch.unsqueeze(1) , x=X_train_batch, unet_number=train_unet_number ) # # ------------------------------------------------------ # if cond_image==True: # loss=model ( # y_train_batch.unsqueeze(1) , # x=None, # cond_images=X_train_batch.unsqueeze(1), # unet_number=train_unet_number # ) # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ if cond_image==True: loss=model ( y_train_batch_in , x=None, cond_images=X_train_batch_in, unet_number=train_unet_number ) # loss.backward( ) torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) optimizer.step() loss_total=loss_total+loss.item() # + train_epoch_loss=train_epoch_loss+loss.item() if steps % print_every == 0: # for progress bar print(".", end="") # if steps>0: if added_steps>0: if steps % print_loss_every_steps == 0: # + for debug if CKeys['Debug_TrainerPack']==2: print('I am here') print("Here is steps: ", steps) norm_loss=loss_total/print_loss_every_steps print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # add a line to the hist file add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' with open(train_hist_file,'a') as f: f.write(add_line) loss_list.append (norm_loss) loss_total=0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ epoch_list.append(e+start_ep) # fig = plt.figure(figsize=(12,8),dpi=200) fig = plt.figure() plt.plot (epoch_list, loss_list, label='Loss') plt.legend() # outname = prefix+ f"loss_{e+start_ep}_{steps}.jpg" outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" # # the order, save then show, matters if CKeys['SlientRun']==1: plt.savefig(outname, dpi=200) else: plt.show() plt.close(fig) # plt.close() if added_steps>0: # if steps>0: if steps % sample_steps == 0: # + for debug if CKeys['Debug_TrainerPack']==2: print('I am here') print("Here is steps: ", 
steps) if plot_unscaled: #test before scaling... plt.plot ( y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), label= 'Unscaled GT' ) plt.legend() plt.show() # # -------------------------------------------------- # GT=y_train_batch.cpu().detach() # GT=resize_image_to( # GT.unsqueeze(1), # model.imagen.image_sizes[train_unet_number-1], # ) #### print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") print ("I. SAMPLING IN TEST SET: ") print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") #### # num_samples = min (num_samples,y_train_batch.shape[0] ) print (f"Producing {num_samples} samples...") if cond_image == True: use_text_embedd=False # - # cond_scales_extended=[1. for i in range(num_samples)] # + cond_scales_extended=cond_scales else: use_text_embedd=True cond_scales_extended=cond_scales # # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # # For ProteinDesigner # sample_loop_omegafold_pLM_ModelB ( # model, # test_loader, # cond_scales=cond_scales_extended, # cond_scales,# #list of cond scales - each sampled... # num_samples=num_samples, #how many samples produced every time tested..... # timesteps=64, # flag=steps, # #reverse=False, # foldproteins=foldproteins, # use_text_embedd= use_text_embedd, # # ++++++++++++++++++++ # train_unet_number=train_unet_number, # ynormfac=ynormfac, # prefix=prefix, # tokenizer_y=tokenizer_y, # Xnormfac=Xnormfac, # tokenizer_X=tokenizer_X, # # ++ # # ++ # CKeys=CKeys, # sample_dir=sample_dir, # steps=steps, # e=e+start_ep, # IF_showfig= CKeys['SlientRun']!=1, # # ++ # pLM_Model=pLM_Model, # pLM_Model_Name=pLM_Model_Name, # image_channels=image_channels, # pLM_alphabet=esm_alphabet, # ) # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # For ProteinPredictor: val_epoch_MSE_list, \ resu_pred, resu_grou = \ sample_loop_omegafold_pLM_ModelB_Predictor ( model, test_loader, cond_scales=[1.], # cond_scales_extended, # #list of cond scales - each sampled... 
num_samples=num_samples, #how many samples produced every time tested..... timesteps=64, flag=steps, #reverse=False, foldproteins=foldproteins, use_text_embedd= use_text_embedd, # ++++++++++++++++++++ train_unet_number=train_unet_number, ynormfac=ynormfac, prefix=prefix, tokenizer_y=tokenizer_y, Xnormfac=Xnormfac, tokenizer_X=tokenizer_X, # ++ # ++ CKeys=CKeys, sample_dir=sample_dir, steps=steps, e=e+start_ep, IF_showfig= CKeys['SlientRun']!=1, # ++ pLM_Model=pLM_Model, pLM_Model_Name=pLM_Model_Name, image_channels=image_channels, pLM_alphabet=esm_alphabet, # ++ esm_layer=esm_layer, ) # record the ERROR on the test set print(f"Epo {str(e+start_ep)}, on TestSet, MSE: {val_epoch_MSE_list[0]}") # only write the 0th case of MSE add_line = str(e+start_ep)+','+str(steps)+','+str(val_epoch_MSE_list[0])+'\n' with open(train_hist_file_on_testset,'a') as f: f.write(add_line) print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") print ("II. SAMPLING FOR DE NOVO: NOT USED in predictor mode") print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") # # +++++++++++++++++++++++++++++++++++++++++ # DeNovoSam_pdbs, fasta_file_list=\ # sample_sequence_omegafold_pLM_ModelB ( # model, # x_data=test_condition_list, # flag=steps, # flag="DeNovo", # , # cond_scales=1., # foldproteins=foldproteins, # # ++++++++++ # ynormfac=ynormfac, # train_unet_number=train_unet_number, # tokenizer_X=tokenizer_X, # Xnormfac=Xnormfac, # max_length=max_length, # prefix=prefix, # tokenizer_y=tokenizer_y, # # ++ # CKeys=CKeys, # sample_dir=sample_dir, # steps=steps, # e=e+start_ep, # IF_showfig= CKeys['SlientRun']!=1, # # ++ # pLM_Model=pLM_Model, # pLM_Model_Name=pLM_Model_Name, # image_channels=image_channels, # pLM_alphabet=esm_alphabet, # ) # if print_error and len(DeNovoSam_pdbs)>0: # print("Calculate SecStr and design error:") # # # for ii in range(len(test_condition_list)): # seq=test_condition_list[ii][0] # DSSPresult,_,sequence_res=get_DSSP_result(DeNovoSam_pdbs[ii]) # print (f"INPUT: 
def train_loop_Old_FromModelB (
    model,
    train_loader,
    test_loader,
    optimizer=None,
    print_every=10,
    epochs=300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss=1000,
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=None,  # list of cond scales - each sampled; None -> [7.5]
    num_samples=2,  # how many samples produced every time tested
    foldproteins=False,
    cond_image=False,  # use cond_images instead of text conditioning
    device=None,
    loss_list=None,
    prefix=None,
    ynormfac=1,
    test_condition_list=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
    max_length=1,
):
    """Legacy training loop for Model B (kept for backward compatibility).

    Each batch is ``(X, y)``; ``y.unsqueeze(1)`` is the training target and
    ``X`` is the condition, passed as ``x=`` (``cond_image`` false) or as
    ``cond_images=X.unsqueeze(1)`` (``cond_image`` true).  When ``trainer``
    is given it computes the loss and applies the update; otherwise ``model``
    is called directly and ``optimizer`` is stepped after clipping the
    gradient norm to 0.5.

    Every ``print_loss`` steps (and never at step 0) the running loss is
    averaged, appended to ``loss_list``, plotted to
    ``prefix + "loss_{e}_{steps}.jpg"``, the test set is sampled with
    ``sample_loop_FromModelB`` and every entry of ``test_condition_list`` is
    sampled with ``sample_sequence_FromModelB``.  With ``save_model`` the
    state dict (and the trainer, if any) is saved at the same moment.

    Args:
        model, train_loader, test_loader: model and data iterables.
        optimizer: required when ``trainer`` is None.
        print_every: print a progress dot every this many steps.
        epochs: number of epochs to run.
        start_ep: offset used only in the "Training epoch" message.
        start_step: initial value of the step counter.
        loss_list: list that receives the averaged losses; a fresh list is
            used when omitted.
        cond_scales, test_condition_list: default to ``[7.5]`` and ``[]``.
        Remaining arguments are forwarded unchanged to the sampling helpers.

    Raises:
        ValueError: if neither ``trainer`` nor ``optimizer`` is provided.
    """
    # Fail early: the original only printed this message and crashed later
    # on ``optimizer.zero_grad()``.
    if trainer is None and optimizer is None:
        raise ValueError("ERROR: If trainer not used, need to provide optimizer.")
    if trainer is not None:
        print ("Trainer provided... will be used")

    # Fresh containers per call; list defaults in the signature would be
    # shared (and keep growing) across calls.
    cond_scales = [7.5] if cond_scales is None else cond_scales
    loss_list = [] if loss_list is None else loss_list
    test_condition_list = [] if test_condition_list is None else test_condition_list

    use_text_embedd = not cond_image

    steps = start_step
    loss_total = 0

    for e in range(1, epochs + 1):
        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        model.train()
        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item in train_loader:
            X_train_batch = item[0].to(device)
            y_train_batch = item[1].to(device)
            target = y_train_batch.unsqueeze(1)

            # Conditioning goes either through ``x`` or through ``cond_images``.
            if cond_image:
                cond_kwargs = dict(x=None, cond_images=X_train_batch.unsqueeze(1))
            else:
                cond_kwargs = dict(x=X_train_batch)

            if trainer is not None:
                # max_batch_size lets the trainer split the batch into
                # smaller chunks and accumulate gradients.
                loss = trainer(
                    target,
                    unet_number=train_unet_number,
                    max_batch_size=max_batch_size,
                    **cond_kwargs,
                )
                trainer.update(unet_number=train_unet_number)
            else:
                optimizer.zero_grad()
                loss = model(target, unet_number=train_unet_number, **cond_kwargs)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_total = loss_total + loss.item()

            if steps % print_every == 0:
                print(".", end="")

            if steps > 0 and steps % print_loss == 0:
                if plot_unscaled:
                    # test before scaling...
                    plt.plot(target[0, 0, :].cpu().detach().numpy(), label='Unscaled GT')
                    plt.legend()
                    plt.show()

                norm_loss = loss_total / print_loss
                print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}")
                loss_list.append(norm_loss)
                loss_total = 0

                # Own figure, closed afterwards, so curves from earlier
                # reports do not accumulate on the same axes.
                fig = plt.figure()
                plt.plot(loss_list, label='Loss')
                plt.legend()
                outname = prefix + f"loss_{e}_{steps}.jpg"
                # the order, save then show, matters
                plt.savefig(outname, dpi=200)
                plt.show()
                plt.close(fig)

                num_samples = min(num_samples, y_train_batch.shape[0])
                print (f"Producing {num_samples} samples...")

                sample_loop_FromModelB (
                    model,
                    test_loader,
                    cond_scales=cond_scales,
                    num_samples=num_samples,
                    timesteps=64,
                    flag=steps,
                    foldproteins=foldproteins,
                    use_text_embedd=use_text_embedd,
                    train_unet_number=train_unet_number,
                    ynormfac=ynormfac,
                    prefix=prefix,
                    tokenizer_y=tokenizer_y,
                    Xnormfac=Xnormfac,
                    tokenizer_X=tokenizer_X,
                )

                print ("SAMPLING FOR DE NOVO:")
                for this_x_data in test_condition_list:
                    sample_sequence_FromModelB (
                        model,
                        x_data=this_x_data,
                        flag=steps,
                        cond_scales=1.,
                        foldproteins=True,
                        ynormfac=ynormfac,
                        train_unet_number=train_unet_number,
                        tokenizer_X=tokenizer_X,
                        Xnormfac=Xnormfac,
                        max_length=max_length,
                        prefix=prefix,
                        tokenizer_y=tokenizer_y,
                    )

                if save_model:
                    # The trainer checkpoint only exists when a trainer is used.
                    if trainer is not None:
                        fname = f"{prefix}trainer_save-model-epoch_{e}.pt"
                        trainer.save(fname)
                        print (f"Model saved: ", fname)
                    fname = f"{prefix}statedict_save-model-epoch_{e}.pt"
                    torch.save(model.state_dict(), fname)
                    print (f"Statedict model saved: ", fname)

            steps = steps + 1

        print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------")
def foldandsavePDB_pdb_fasta (
    sequence,
    filename_out,
    num_cycle=16,
    flag=0,
    # ++++++++++++
    prefix=None,
):
    """Write ``sequence`` to a FASTA file and fold it with OmegaFold.

    The FASTA file is ``{prefix}fasta_in_{flag}.fasta`` with the record name
    ``{filename_out}_{flag}``.  ``omegafold`` is then invoked with ``prefix``
    as its output location and the module-level ``device`` as compute device.

    Args:
        sequence: amino-acid sequence written as the FASTA record body.
        filename_out: base name of the FASTA record / resulting PDB file.
        num_cycle: value passed to ``--num_cycle``.
        flag: tag appended to the generated file names.
        prefix: string prepended to every generated path (presumably a
            directory ending in a path separator -- confirm with callers).
            ``None`` is treated as the current directory.

    Returns:
        ``(PDB_result, filename)``: the expected PDB path and the FASTA path.
        The PDB file only exists if OmegaFold ran successfully.
    """
    # Local import keeps this block self-contained.
    import subprocess

    # ``None`` used to be formatted literally into the paths ("None...").
    prefix = "" if prefix is None else prefix

    filename = f"{prefix}fasta_in_{flag}.fasta"
    print ("Writing FASTA file: ", filename)
    OUTFILE = f"{filename_out}_{flag}"
    with open(filename, mode='w') as f:
        f.write(f'>{OUTFILE}\n')
        f.write(f'{sequence}')

    print (f"Now run OmegaFold.... on device={device}")
    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through unchanged.
    cmd = [
        "omegafold",
        str(filename),
        str(prefix) or ".",
        "--num_cycle",
        str(num_cycle),
        f"--device={device}",
    ]
    try:
        completed = subprocess.run(cmd, stdout=subprocess.PIPE, text=True, check=False)
        print(completed.stdout)
    except FileNotFoundError:
        # Stay best-effort like the old os.popen call, but say what happened.
        print("ERROR: 'omegafold' executable not found; no structure was predicted.")
    print ("Done OmegaFold")

    PDB_result = f"{prefix}{OUTFILE}.pdb"
    print (f"Resulting PDB file...: {PDB_result}")

    return PDB_result, filename


def foldandsavePDB (
    sequence,
    filename_out,
    num_cycle=16,
    flag=0,
    # ++++++++++++
    prefix=None,
):
    """Fold ``sequence`` with OmegaFold and return only the PDB path.

    Thin wrapper around ``foldandsavePDB_pdb_fasta``; see that function for
    the meaning of the arguments and the files that are written.
    """
    PDB_result, _ = foldandsavePDB_pdb_fasta(
        sequence,
        filename_out,
        num_cycle=num_cycle,
        flag=flag,
        prefix=prefix,
    )
    return PDB_result
import py3Dmol


def plot_plddt_legend(dpi=100):
    """Draw a colour legend for plDDT confidence bands.

    Creates a small figure, adds one zero-sized bar per colour so that the
    legend gets matching handles, and hides the axes.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can chain ``.show()``.
    """
    thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
    plt.figure(figsize=(1,0.1),dpi=dpi)
    for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]:
        plt.bar(0, 0, color=c)
    plt.legend(thresh, frameon=False,
               loc='center', ncol=6,
               handletextpad=1,
               columnspacing=1,
               markerscale=0.5,)
    plt.axis(False)
    return plt


color = "lDDT"  # choose from ["chain", "lDDT", "rainbow"]
show_sidechains = False  # choose from {type:"boolean"}
show_mainchains = False  # choose from {type:"boolean"}


def show_pdb(pdb_file, flag=0, show_sidechains=False, show_mainchains=False, color="lDDT", chains=1):
    """Build a py3Dmol view of a PDB file.

    Args:
        pdb_file: path of the PDB file to display.
        flag: accepted for backward compatibility; not used.
        show_sidechains: add stick/sphere styles for side-chain atoms.
        show_mainchains: add stick style for the backbone atoms.
        color: ``"lDDT"`` (colour by the B-factor column, 50-90),
            ``"rainbow"`` or ``"chain"``.
        chains: number of chains to colour in ``"chain"`` mode (at most 8
            colours are available).  The original code read this from the
            undefined names ``queries``/``is_complex`` and raised NameError.

    Returns:
        The configured ``py3Dmol`` view object.
    """
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
    # Read through a context manager so the file handle is closed.
    with open(pdb_file, 'r') as pdb_handle:
        view.addModel(pdb_handle.read(), 'pdb')

    if color == "lDDT":
        view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
    elif color == "rainbow":
        view.setStyle({'cartoon': {'color':'spectrum'}})
    elif color == "chain":
        chain_colors = ["lime","cyan","magenta","yellow","salmon","white","blue","orange"]
        # Distinct loop names: the original rebound the ``color`` parameter.
        for _, chain_id, chain_color in zip(range(chains), "ABCDEFGH", chain_colors):
            view.setStyle({'chain': chain_id}, {'cartoon': {'color': chain_color}})

    if show_sidechains:
        BB = ['C','O','N']
        view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                      {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    if show_mainchains:
        BB = ['C','O','N','CA']
        view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

    view.zoomTo()
    if color == "lDDT":
        plot_plddt_legend().show()
    return view
def get_avg_Bfac (file='./output_v3/[0.0 0.5 0.0 0.0 0.0 0.0 0.0 0.0].pdb'):
    """Compute and plot the per-atom B-factors of a PDB file.

    Args:
        file: path of the PDB file to parse.

    Returns:
        ``(avg_B, bfac_list)``: the mean B-factor over all atoms and the list
        of per-atom B-factors in iteration order.

    Raises:
        ValueError: if the parsed structure contains no atoms (the original
            failed with an uninformative ZeroDivisionError).
    """
    # NOTE(review): PDBParser is not imported in the visible part of this
    # file; presumably Bio.PDB.PDBParser -- confirm it is in scope.
    p = PDBParser()
    structure = p.get_structure("X", file)

    bfac_list = []
    for PDBmodel in structure:
        for chain in PDBmodel:
            for residue in chain:
                for atom in residue:
                    bfac_list.append(atom.get_bfactor())

    if not bfac_list:
        raise ValueError(f"No atoms found in {file}; cannot compute an average B-factor.")

    avg_B = sum(bfac_list) / len(bfac_list)
    print (f"For {file}, average B-factor={avg_B}")

    plt.plot(bfac_list, label='lDDT')
    plt.xlabel('Atom #')
    plt.ylabel('iDDT')
    plt.legend()
    plt.show()
    return avg_B, bfac_list


def sample_sequence_normalized_Bfac (seccs=(0.3, 0.3, 0.1, 0., 0., 0., 0., 0.)):
    """Sample a protein for a normalized condition vector and score it.

    ``seccs`` is L2-normalized and then rescaled so its entries sum to one,
    passed to ``sample_sequence`` together with the module-level ``model``,
    and the first returned PDB is evaluated with ``get_avg_Bfac``.

    Args:
        seccs: sequence of condition values (presumably secondary-structure
            fractions -- confirm with callers).

    Returns:
        ``(PDB, avg)``: whatever ``sample_sequence`` returns and the average
        B-factor of ``PDB[0]``.

    Raises:
        ValueError: if the entries of ``seccs`` sum to zero, which would
            otherwise silently produce NaN conditioning values.
    """
    sample_numbers = torch.tensor([list(seccs)])
    sample_numbers = torch.nn.functional.normalize(sample_numbers, dim=1)

    total = torch.sum(sample_numbers)
    if total.item() == 0:
        raise ValueError("seccs must not sum to zero; cannot normalize the condition vector.")
    sample_numbers = sample_numbers / total

    PDB = sample_sequence(model,
                          X=sample_numbers,
                          flag=0, cond_scales=1, foldproteins=True,
                          )

    avg, _ = get_avg_Bfac(file=PDB[0])

    return PDB, avg
def train_loop_Old_FromModelA (
    model,
    train_loader,
    test_loader,
    #
    optimizer=None,
    print_every=1,
    epochs=300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss_every_steps=1000,
    #
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=None,  # list of cond scales; None -> [1.0]
    num_samples=2,  # how many samples produced every time tested
    foldproteins=False,
    # ++
    cond_image=False,  # not use cond_images... for model A
    cond_text=True,  # use condi_text... for model A
    # +
    device=None,
    loss_list=None,
    epoch_list=None,
    train_hist_file=None,
    train_hist_file_full=None,
    prefix=None,
    Xnormfac=None,
    ynormfac=1.,
    tokenizer_X=None,
    tokenizer_y=None,
    test_condition_list=None,
    max_length_Y=1,
    max_text_len_X=1,
    CKeys=None,
    sample_steps=1,
    sample_dir=None,
    save_every_epoch=1,
    save_point_info_file=None,
    store_dir=None,
):
    """Legacy training loop for Model A (kept for backward compatibility).

    Each batch is ``(X, y)``; ``y.unsqueeze(1)`` is the training target and
    ``X`` the condition, passed as ``x=`` (``cond_image`` false) or as
    ``cond_images=X.unsqueeze(1)`` (``cond_image`` true).  With a ``trainer``
    the trainer computes the loss and applies the update; otherwise ``model``
    is called directly and ``optimizer`` is stepped after clipping the
    gradient norm to 0.5.

    Reporting:
      * every ``print_loss_every_steps`` steps the running loss is averaged,
        appended to ``train_hist_file`` / ``loss_list`` / ``epoch_list`` and
        plotted (saved under ``sample_dir`` when ``CKeys['SlientRun']==1``,
        shown otherwise);
      * every ``sample_steps`` steps the test set and ``test_condition_list``
        are sampled; ``sample_steps <= 0`` switches sampling off;
      * after every epoch the mean epoch loss is appended to
        ``train_hist_file_full``;
      * with ``save_model``, from the second epoch of this call on and when
        ``(e+start_ep) % save_every_epoch == 0``, checkpoints are written
        under ``store_dir`` and ``save_point_info_file`` is rewritten.

    ``CKeys`` must provide ``'Debug_TrainerPack'`` and ``'SlientRun'``
    whenever loss reporting or sampling is triggered.

    Raises:
        ValueError: if neither ``trainer`` nor ``optimizer`` is provided.
    """
    # Fail early: the original only printed this message and crashed later
    # on ``optimizer.zero_grad()``.
    if trainer is None and optimizer is None:
        raise ValueError("ERROR: If trainer not used, need to provide optimizer.")
    if trainer is not None:
        print ("Trainer provided... will be used")

    # Fresh containers per call; list defaults in the signature would be
    # shared (and keep growing) across calls.
    cond_scales = [1.0] if cond_scales is None else cond_scales
    loss_list = [] if loss_list is None else loss_list
    epoch_list = [] if epoch_list is None else epoch_list
    test_condition_list = [] if test_condition_list is None else test_condition_list

    steps = start_step
    loss_total = 0
    # Last windowed loss; stays None until the first loss report so that the
    # save-point file can fall back to the epoch loss instead of failing.
    norm_loss = None

    for e in range(1, epochs + 1):
        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        # TRAINING
        train_epoch_loss = 0
        model.train()

        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item in train_loader:
            steps += 1

            X_train_batch = item[0].to(device)
            y_train_batch = item[1].to(device)
            target = y_train_batch.unsqueeze(1)

            if cond_image:
                # Model B style: condition via image/sequence
                cond_kwargs = dict(x=None, cond_images=X_train_batch.unsqueeze(1))
            else:
                # Model A: condition via text
                cond_kwargs = dict(x=X_train_batch)

            if trainer is not None:
                # max_batch_size lets the trainer split the batch into
                # smaller chunks and accumulate gradients.
                loss = trainer(
                    target,
                    unet_number=train_unet_number,
                    max_batch_size=max_batch_size,
                    **cond_kwargs,
                )
                trainer.update(unet_number=train_unet_number)
            else:
                optimizer.zero_grad()
                loss = model(target, unet_number=train_unet_number, **cond_kwargs)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_total = loss_total + loss.item()
            train_epoch_loss = train_epoch_loss + loss.item()

            if steps % print_every == 0:
                # for progress bar
                print(".", end="")

            # ----------------------------------------------------------
            # record loss block
            # ----------------------------------------------------------
            if print_loss_every_steps > 0 and steps % print_loss_every_steps == 0:
                if CKeys['Debug_TrainerPack']==2:
                    print("Here is step: ", steps)

                norm_loss = loss_total / print_loss_every_steps
                print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}")
                # add a line to the hist file
                add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n'
                with open(train_hist_file,'a') as f:
                    f.write(add_line)

                loss_list.append(norm_loss)
                loss_total = 0
                epoch_list.append(e+start_ep)

                fig = plt.figure()
                plt.plot(epoch_list, loss_list, label='Loss')
                plt.legend()
                outname = sample_dir + f"loss_{e+start_ep}_{steps}.jpg"
                # the order, save then show, matters
                if CKeys['SlientRun']==1:
                    plt.savefig(outname, dpi=200)
                else:
                    plt.show()
                plt.close(fig)

            # ----------------------------------------------------------
            # sample in test set block
            # set sample_steps <= 0 to switch off this block
            # ----------------------------------------------------------
            # Check the sign first: ``steps % 0`` would raise.
            if sample_steps > 0 and steps % sample_steps == 0:
                if CKeys['Debug_TrainerPack']==2:
                    print("Here is steps: ", steps)

                if plot_unscaled:
                    # test before scaling...
                    plt.plot(
                        target[0, 0, :].cpu().detach().numpy(),
                        label='Unscaled GT'
                    )
                    plt.legend()
                    plt.show()

                print ("I. SAMPLING IN TEST SET: ")
                num_samples = min(num_samples, y_train_batch.shape[0])
                print (f"Producing {num_samples} samples...")

                sample_loop_omegafold_ModelA (
                    model,
                    test_loader,
                    cond_scales=cond_scales,
                    num_samples=num_samples,
                    timesteps=None,
                    flag=e+start_ep,
                    foldproteins=foldproteins,
                    # add condi_key
                    cond_image=cond_image,
                    cond_text=cond_text,
                    skip_steps=0,
                    #
                    max_text_len=max_text_len_X,
                    max_length=max_length_Y,
                    # ++++++++++++++++++++
                    train_unet_number=train_unet_number,
                    ynormfac=ynormfac,
                    prefix=prefix,
                    #
                    tokenizer_y=tokenizer_y,
                    Xnormfac_CondiText=Xnormfac,
                    tokenizer_X_CondiText=tokenizer_X,
                    # ++
                    CKeys=CKeys,
                    sample_dir=sample_dir,
                    steps=steps,
                    e=e+start_ep,
                    IF_showfig=CKeys['SlientRun']!=1,
                )

                print ("II. SAMPLING FOR DE NOVO:")
                # NOTE(review): ``train_unet_number=1``, ``foldproteins=True``
                # and ``e=e`` (without ``start_ep``) are hard-coded here,
                # unlike the test-set call above -- kept as in the original;
                # confirm whether that is intended.
                sample_sequence_omegafold_ModelA (
                    model,
                    X=test_condition_list,  # from text conditioning X
                    flag=e+start_ep,
                    cond_scales=cond_scales,
                    foldproteins=True,
                    X_string=None,  # from text conditioning X_string
                    x_data=None,  # from image conditioning x_data
                    skip_steps=0,
                    inpaint_images=None,
                    inpaint_masks=None,
                    inpaint_resample_times=None,
                    init_images=None,
                    num_cycle=16,  # for omegafolding
                    calc_error=False,  # for check on folded results
                    # tokenizers
                    tokenizer_X_forImageCondi=None,  # for x_data
                    Xnormfac_forImageCondi=1.,
                    tokenizer_X_forTextCondi=None,  # for X if NEEDED only
                    Xnormfac_forTextCondi=1.,
                    tokenizer_y=tokenizer_y,  # for output Y
                    ynormfac=ynormfac,
                    # length
                    train_unet_number=1,
                    max_length_Y=max_length_Y,  # for Y, X_forImageCondi
                    max_text_len=max_text_len_X,  # for X_forTextCondi
                    # other info
                    steps=steps,
                    e=e,
                    sample_dir=sample_dir,
                    prefix=prefix,
                    IF_showfig=CKeys['SlientRun']!=1,
                    CKeys=CKeys,
                    # TBA to Model B
                    normalize_X_cond_to_one=False,
                )

        # summarize loss over every epoch:
        norm_loss_over_e = train_epoch_loss/len(train_loader)
        print("\nnorm_loss over 1 epoch: ", norm_loss_over_e)
        # write this into "train_hist_file_full"
        add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
        with open(train_hist_file_full,'a') as f:
            f.write(add_line)

        # save model every this epoches
        if save_model and (e+start_ep) % save_every_epoch==0 and e>1:
            # The trainer checkpoint only exists when a trainer is used.
            if trainer is not None:
                fname = f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt"
                trainer.save(fname)
                print (f"Model saved: ", fname)
            fname = f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt"
            torch.save(model.state_dict(), fname)
            print (f"Statedict model saved: ", fname)
            # add a saving point file; fall back to the epoch loss when no
            # windowed loss has been computed yet.
            saved_loss = norm_loss if norm_loss is not None else norm_loss_over_e
            top_line = 'epoch,steps,norm_loss'+'\n'
            add_line = str(e+start_ep)+','+str(steps)+','+str(saved_loss)+'\n'
            with open(save_point_info_file, "w") as f:
                f.write(top_line)
                f.write(add_line)

        print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------")
train_unet_number=1, print_loss_every_steps=1000, # trainer=None, plot_unscaled=False, max_batch_size=4, save_model=False, cond_scales=[1.0], #list of cond scales num_samples=2, #how many samples produced every time tested..... foldproteins=False, # ++ cond_image=False, # not use cond_images... for model A cond_text=True, # use condi_text... for model A # + device=None, loss_list=[], epoch_list=[], train_hist_file=None, train_hist_file_full=None, prefix=None, # not used in this function Xnormfac=None, ynormfac=1., tokenizer_X=None, tokenizer_y=None, test_condition_list=[], max_length_Y=1, max_text_len_X=1, CKeys=None, sample_steps=1, sample_dir=None, save_every_epoch=1, save_point_info_file=None, store_dir=None, # ++ for pLM pLM_Model_Name=None, image_channels=None, print_error=False, # not defined for Problem6 # True, ): # #+ # Xnormfac=Xnormfac.to(model.device) if not exists (trainer): if not exists (optimizer): print ("ERROR: If trainer not used, need to provide optimizer.") if exists (trainer): print ("Trainer provided... will be used") # -------------------------------- # steps=start_step # ++++++++++++++++++++++++++++++++ steps=start_step added_steps=0 loss_total=0 # ++ for pLM if pLM_Model_Name=='None': pLM_Model=None elif pLM_Model_Name=='esm2_t33_650M_UR50D': # dim: 1280 esm_layer=33 pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() len_toks=len(esm_alphabet.all_toks) pLM_Model.eval() pLM_Model. to(device) elif pLM_Model_Name=='esm2_t36_3B_UR50D': # dim: 2560 esm_layer=36 pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() len_toks=len(esm_alphabet.all_toks) pLM_Model.eval() pLM_Model. to(device) elif pLM_Model_Name=='esm2_t30_150M_UR50D': # dim: 640 esm_layer=30 pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() len_toks=len(esm_alphabet.all_toks) pLM_Model.eval() pLM_Model. 
to(device) elif pLM_Model_Name=='esm2_t12_35M_UR50D': # dim: 480 esm_layer=12 pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() len_toks=len(esm_alphabet.all_toks) pLM_Model.eval() pLM_Model. to(device) else: print("pLM model is missing...") for e in range(1, epochs+1): # start = time.time() torch.cuda.empty_cache() print ("######################################################################################") start = time.time() print ("NOW: Training epoch: ", e+start_ep) # TRAINING train_epoch_loss = 0 model.train() print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") for item in train_loader: # ++ steps += 1 added_steps += 1 X_train_batch= item[0].to(device) y_train_batch=item[1].to(device) # project y_ into embedding space if CKeys["Debug_TrainerPack"]==1: print("Initial unload the dataloader items: ...") print(X_train_batch.shape) print(y_train_batch.shape) # ++ # project the AA seq into embedding space # for output, it is shared between ModelA and ModelB # # -- # if pLM_Model_Name=='None': # # just use the encoded sequence # y_train_batch_in = y_train_batch.unsqueeze(1) # # pass # elif pLM_Model_Name=='esm2_t33_650M_UR50D': # with torch.no_grad(): # results = pLM_Model( # y_train_batch, # repr_layers=[33], # return_contacts=False, # ) # y_train_batch_in = results["representations"][33] # y_train_batch_in = rearrange( # y_train_batch_in, # 'b l c -> b c l' # ) # else: # print(f"Required pLM name is not defined!!") # ++ if pLM_Model_Name=='None': # just use the encoded sequence y_train_batch_in = y_train_batch.unsqueeze(1) # pass else: # for ESM models # pLM_Model_Name=='esm2_t33_650M_UR50D': with torch.no_grad(): results = pLM_Model( y_train_batch, repr_layers=[esm_layer], return_contacts=False, ) y_train_batch_in = results["representations"][esm_layer] y_train_batch_in = rearrange( y_train_batch_in, 'b l c -> b c l' ) # # For input part, this block is different for ModelA and ModelB if cond_image==False: # 
model A: X: text_condi, not affected by pLM X_train_batch_in = X_train_batch else: # model B: X: cond_img, will be affected by pLM X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) # # + for debug if CKeys["Debug_TrainerPack"]==1: print("After pLM model, the shape of X and y for training:") print("X_train_batch_in.dim: ", X_train_batch_in.shape) print("y_train_batch_in.dim: ", y_train_batch_in.shape) if exists (trainer): if cond_image==False: # ======================================== # Model A: condition via text # ======================================== # this block depends on the model:forward loss = trainer( # # -------------------------------- # X_train_batch, # y_train_batch.unsqueeze(1) , # # ++++++++++++++++++++++++++++++++ # y_train_batch.unsqueeze(1) , # x=X_train_batch, # ++ pLM y_train_batch_in, x=X_train_batch_in, # unet_number=train_unet_number, max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory ) if cond_image==True: # ======================================== # Model B: condition via image/sequence # ======================================== # # -- # # added for future: Train_loop B # loss = trainer( # y_train_batch.unsqueeze(1) , # x=None, # cond_images=X_train_batch.unsqueeze(1), # unet_number=train_unet_number, # max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory # ) # ++ from pLM+ModelB loss = trainer( y_train_batch_in, # true image x=None, # tokenized text cond_images=X_train_batch_in, # cond_image unet_number=train_unet_number, max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory ) # pass # trainer.update(unet_number = train_unet_number) else: optimizer.zero_grad() if cond_image==False: # this block depends on the model:forward loss=model ( # # 
-------------------------------- # X_train_batch, # y_train_batch.unsqueeze(1) , # # ++++++++++++++++++++++++++++++++ # y_train_batch.unsqueeze(1) , # x=X_train_batch, # ++ pLM y_train_batch_in, x=X_train_batch_in, # unet_number=train_unet_number ) if cond_image==True: # added for future: Train_loop B # # -- # loss=model ( # y_train_batch.unsqueeze(1) , # x=None, # cond_images=X_train_batch.unsqueeze(1), # unet_number=train_unet_number # ) # ++ from pLM loss=model ( y_train_batch_in , x=None, cond_images=X_train_batch_in, unet_number=train_unet_number ) # loss.backward( ) torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) optimizer.step() loss_total=loss_total+loss.item() # + train_epoch_loss=train_epoch_loss+loss.item() if steps % print_every == 0: # for progress bar print(".", end="") # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ # record loss block # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ # if steps>0: if added_steps>0: if steps % print_loss_every_steps == 0: # + for debug if CKeys['Debug_TrainerPack']==2: print("Here is step: ", steps) norm_loss=loss_total/print_loss_every_steps print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # add a line to the hist file add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' with open(train_hist_file,'a') as f: f.write(add_line) loss_list.append (norm_loss) loss_total=0 # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ epoch_list.append(e+start_ep) fig = plt.figure() plt.plot (epoch_list, loss_list, label='Loss') plt.legend() # outname = prefix+ f"loss_{e}_{steps}.jpg" outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" # # the order, save then show, matters if CKeys['SlientRun']==1: plt.savefig(outname, dpi=200) else: plt.show() plt.close(fig) # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ # sample in test 
set block # set sample_steps < 0 to switch off this block # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ # if steps>0: if added_steps>0: if steps % sample_steps == 0 and sample_steps > 0: # + for debug if CKeys['Debug_TrainerPack']==2: print("Here is steps: ", steps) if plot_unscaled: # test before scaling... plt.plot ( y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), label= 'Unscaled GT' ) plt.legend() plt.show() # # -- look like not used # #rescale GT to properly plot # GT=y_train_batch.cpu().detach() # GT=resize_image_to( # GT.unsqueeze(1), # model.imagen.image_sizes[train_unet_number-1], # ) #### print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ") print ("I. SAMPLING IN TEST SET: ") print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") #### num_samples = min (num_samples,y_train_batch.shape[0] ) print (f"Producing {num_samples} samples...") sample_loop_omegafold_pLM_ModelA ( model, test_loader, cond_scales=cond_scales, num_samples=num_samples, #how many samples produced every time tested..... timesteps=None, flag=e+start_ep, # steps, foldproteins=foldproteins, # add condi_key cond_image=cond_image, # Not used for now cond_text=cond_text, # Not used for now skip_steps=0, # max_text_len=max_text_len_X, max_length=max_length_Y, # ++++++++++++++++++++ train_unet_number=train_unet_number, ynormfac=ynormfac, prefix=prefix, # tokenizer_y=tokenizer_y, Xnormfac_CondiText=Xnormfac, tokenizer_X_CondiText=tokenizer_X, # ++ CKeys=CKeys, sample_dir=sample_dir, steps=steps, e=e+start_ep, IF_showfig= CKeys['SlientRun']!=1 , # ++ for pLM pLM_Model=pLM_Model, pLM_Model_Name=pLM_Model_Name, image_channels=image_channels, pLM_alphabet=esm_alphabet, ) print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ") print ("II. 
SAMPLING FOR DE NOVO:") print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") DeNovoSam_pdbs, fasta_file_list=\ sample_sequence_omegafold_pLM_ModelA ( # # ---------------------------------------------- # model, # X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]], # foldproteins=foldproteins, # flag=steps,cond_scales=1., # ++++++++++++++++++++++++++++++++++++++++++++++ model, X=test_condition_list, # [[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X flag=e+start_ep, # steps, # 0, cond_scales=cond_scales, # 1., foldproteins=True, # False, X_string=None, # from text conditioning X_string x_data=None, # from image conditioning x_data skip_steps=0, inpaint_images=None, # in formation Y data inpaint_masks = None, inpaint_resample_times = None, init_images = None, num_cycle=16, # for omegafolding calc_error=False, # for check on folded results, not used for every case # ++++++++++++++++++++++++++ # tokenizers tokenizer_X_forImageCondi=None, # for x_data Xnormfac_forImageCondi=1., tokenizer_X_forTextCondi=None, # for X if NEEDED only Xnormfac_forTextCondi=1., tokenizer_y=tokenizer_y, # None, # for output Y ynormfac=ynormfac, # length train_unet_number=1, max_length_Y=max_length_Y, # for Y, X_forImageCondi max_text_len=max_text_len_X, # for X_forTextCondi # other info steps=steps, # None, e=e, # None, sample_dir=sample_dir, # None, prefix=prefix, # None, IF_showfig= CKeys['SlientRun']!=1, # True, CKeys=CKeys, # TBA to Model B normalize_X_cond_to_one=False, # ++ for pLM pLM_Model=pLM_Model, pLM_Model_Name=pLM_Model_Name, image_channels=image_channels, pLM_alphabet=esm_alphabet, ) # sample_sequence (model, # X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],foldproteins=foldproteins, # flag=steps,cond_scales=1., # ) # summerize loss over every epoch: norm_loss_over_e = train_epoch_loss/len(train_loader) print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ # write this into "train_hist_file_full" 
# from original, not used any more
def train_loop_Model_A (
    model,
    train_loader,
    test_loader,
    optimizer=None,
    print_every=10,
    epochs= 300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss=1000,
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=[1.0],  # list of cond scales; only read here, never mutated
    num_samples=2,  # how many samples produced every time tested
    foldproteins=False,
):
    """Legacy training loop for Model A (kept for reference).

    Runs ``epochs`` passes over ``train_loader``. Every batch item is read as
    ``(X, y)``; both are moved to the module-level ``device`` and ``y`` gets
    an extra axis via ``unsqueeze(1)`` before it is handed to the model.

    Two update paths exist:

    * ``trainer`` given: ``trainer(...)`` returns the loss (the original
      comment says it splits the batch by ``max_batch_size`` and accumulates
      gradients) and ``trainer.update`` applies the step.
    * otherwise: a manual ``optimizer`` step with the gradient norm clipped
      at 0.5.

    Whenever ``steps`` is positive and a multiple of ``print_loss``:
    the mean loss of the window is printed and appended to ``loss_list``,
    the loss curve is saved as ``prefix + "loss_{e}_{steps}.jpg"``,
    ``sample_loop`` is run on ``test_loader``, ``sample_sequence`` is run for
    two fixed conditions, and (if ``save_model``) checkpoints are written.

    NOTE(review): ``loss_list``, ``prefix``, ``sample_loop`` and
    ``sample_sequence`` are not parameters of this function; they are
    resolved from module scope at call time -- confirm they exist before
    reviving this loop.

    Args:
        model: network called as ``model(X, y, unet_number=...)``; must
            provide ``train()``, ``parameters()`` and ``state_dict()``.
        train_loader: iterable of ``(X, y)`` batches; ``len()`` is printed.
        test_loader: forwarded to ``sample_loop``.
        optimizer: required when ``trainer`` is not given.
        print_every: a ``.`` is printed whenever ``steps`` is a multiple.
        epochs: number of epochs to run.
        start_ep: offset added to the epoch number in the progress print.
        start_step: initial value of the step counter.
        train_unet_number: forwarded as ``unet_number``.
        print_loss: window length (in steps) for loss reporting, sampling
            and checkpointing.
        trainer: optional trainer object exposing ``__call__``, ``update``
            and ``save``.
        plot_unscaled: if true, plot the first target of the current batch.
        max_batch_size: forwarded to ``trainer``.
        save_model: write checkpoints at every reporting step.
        cond_scales: forwarded to ``sample_loop``.
        num_samples: clipped to the batch size for the progress message
            (``sample_loop`` itself is called with ``num_samples=1``).
        foldproteins: forwarded to the sampling helpers.

    Raises:
        ValueError: if neither ``trainer`` nor ``optimizer`` is supplied.
    """
    # Fail early with the original message instead of crashing on the first
    # batch with an AttributeError on ``None``.
    if trainer is None and optimizer is None:
        raise ValueError("ERROR: If trainer not used, need to provide optimizer.")
    if trainer is not None:
        print ("Trainer provided... will be used")

    steps = start_step
    loss_total = 0

    for e in range(1, epochs + 1):
        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        # TRAINING
        model.train()
        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item in train_loader:
            X_train_batch = item[0].to(device)
            y_train_batch = item[1].to(device)

            if trainer is not None:
                loss = trainer(
                    X_train_batch,
                    y_train_batch.unsqueeze(1),
                    unet_number=train_unet_number,
                    max_batch_size=max_batch_size,
                )
                trainer.update(unet_number=train_unet_number)
            else:
                optimizer.zero_grad()
                loss = model(
                    X_train_batch,
                    y_train_batch.unsqueeze(1),
                    unet_number=train_unet_number,
                )
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_total = loss_total + loss.item()

            if steps % print_every == 0:
                print(".", end="")

            # Reporting / sampling / checkpointing all share the same trigger.
            if steps > 0 and steps % print_loss == 0:
                if plot_unscaled:
                    plt.plot(
                        y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),
                        label='Unscaled GT',
                    )
                    plt.legend()
                    plt.show()

                norm_loss = loss_total/print_loss
                print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}")
                loss_list.append(norm_loss)
                loss_total = 0

                plt.plot(loss_list, label='Loss')
                plt.legend()
                outname = prefix+ f"loss_{e}_{steps}.jpg"
                # save before show: show() may clear the current figure
                plt.savefig(outname, dpi=200)
                plt.show()

                num_samples = min(num_samples, y_train_batch.shape[0])
                print (f"Producing {num_samples} samples...")
                sample_loop(
                    model,
                    test_loader,
                    cond_scales=cond_scales,
                    num_samples=1,
                    timesteps=64,
                    flag=steps,
                    foldproteins=foldproteins,
                )

                print ("SAMPLING FOR DE NOVO:")
                sample_sequence(
                    model,
                    X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]],
                    foldproteins=foldproteins,
                    flag=steps,
                    cond_scales=1.,
                )
                sample_sequence(
                    model,
                    X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],
                    foldproteins=foldproteins,
                    flag=steps,
                    cond_scales=1.,
                )

                if save_model:
                    # The trainer checkpoint only exists on the trainer path;
                    # the plain state dict is written in both cases.
                    if trainer is not None:
                        fname = f"{prefix}trainer_save-model-epoch_{e}.pt"
                        trainer.save(fname)
                    fname = f"{prefix}statedict_save-model-epoch_{e}.pt"
                    torch.save(model.state_dict(), fname)
                    print ("Model saved: ")

            steps = steps+1

        print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------")
# +++
def sample_sequence_omegafold_ModelA (
    model,
    X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]],  # from text conditioning X
    flag=0,
    cond_scales=1.,
    foldproteins=False,
    X_string=None,  # from text conditioning X_string
    x_data=None,  # from image conditioning x_data
    skip_steps=0,
    inpaint_images=None,  # in formation Y data
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,  # for omegafolding
    calc_error=False,  # for check on folded results, not used for every case
    # ++++++++++++++++++++++++++
    # tokenizers
    tokenizer_X_forImageCondi=None,  # for x_data
    Xnormfac_forImageCondi=1.,
    tokenizer_X_forTextCondi=None,  # for X if NEEDED only
    Xnormfac_forTextCondi=1.,
    tokenizer_y=None,  # for output Y
    ynormfac=1,
    # length
    train_unet_number=1,
    max_length_Y=1,  # for Y, X_forImageCondi
    max_text_len=1,  # for X_forTextCondi
    # other info
    steps=None,
    e=None,
    sample_dir=None,
    prefix=None,
    IF_showfig=True,
    CKeys=None,
    # TBA to Model B
    normalize_X_cond_to_one=False,
):
    """Draw one sample per condition from ``model`` and optionally fold it.

    The number of samples is ``len()`` of the condition source; when several
    sources are given the later one wins (``X`` < ``X_string`` < ``x_data``).

    Per sample the function
      1. builds the text condition ``X_cond`` from ``X[i]`` (float tensor
         with a leading axis, on the module-level ``device``) or from
         ``X_string[i]`` (tokenized, padded/truncated to ``max_text_len``,
         divided by ``Xnormfac_forTextCondi``),
      2. builds the image condition: without ``tokenizer_X_forImageCondi``
         ``x_data[i] / Xnormfac_forImageCondi`` is passed pre-tokenized,
         otherwise the raw ``x_data[i]`` is passed with the tokenizer,
      3. calls ``model.sample``, rescales by ``ynormfac`` and rounds,
      4. plots the first predicted row (shown or saved under ``sample_dir``),
      5. decodes tokens with ``tokenizer_y.sequences_to_texts`` and
         normalizes the text (upper case, no spaces),
      6. if ``foldproteins``: folds the first decoded sequence with
         ``foldandsavePDB``, writes a FASTA file and copies the PDB to
         ``{sample_dir}DeNovoSampling_{i}_epo_{e}_step_{steps}.pdb``.

    NOTE(review): ``sequence``, ``foldandsavePDB``, ``write_fasta``,
    ``show_pdb``, ``show_sidechains``, ``show_mainchains`` and ``color`` are
    resolved from module scope -- confirm they are defined in this file.
    NOTE(review): the original re-created ``pdb_list`` inside the loop, so
    only the last sample survived; it is now created once. Entries are added
    only for folded samples, matching sample_sequence_omegafold_pLM_ModelA --
    confirm no caller expects a ``None`` entry when folding is off.

    Args:
        model: object exposing ``device`` and ``sample(...)``.
        X: sequence of numeric condition vectors, or None.
        flag: forwarded to ``foldandsavePDB`` and ``show_pdb``.
        cond_scales: forwarded to ``model.sample`` as ``cond_scale``.
        foldproteins: fold and save a structure for every sample.
        X_string: sequence of raw text conditions, or None.
        x_data: sequence of image/sequence conditions, or None.
        skip_steps, inpaint_images, inpaint_masks, inpaint_resample_times,
        init_images: forwarded unchanged to ``model.sample``.
        num_cycle: forwarded to ``foldandsavePDB``.
        calc_error: run ``get_Model_A_error`` when ``CKeys['Problem_ID']``
            is 7.
        tokenizer_X_forImageCondi, Xnormfac_forImageCondi: tokenizer and
            normalization factor for ``x_data``.
        tokenizer_X_forTextCondi, Xnormfac_forTextCondi: tokenizer and
            normalization factor for ``X_string``.
        tokenizer_y, ynormfac: decoder and scale factor for the output.
        train_unet_number: forwarded as ``stop_at_unet_number``.
        max_length_Y, max_text_len: forwarded to ``model.sample``;
            ``max_text_len`` is also the padding length for ``X_string``.
        steps, e: only used in output file names.
        sample_dir: string prefix for every file written.
        prefix: unused, kept for interface compatibility.
        IF_showfig: show figures/structures instead of saving the figure.
        CKeys: dict of control keys (``Debug_TrainerPack``, ``Problem_ID``);
            ``None`` is treated as empty.
        normalize_X_cond_to_one: divide ``X_cond`` by its sum.

    Returns:
        list of PDB file paths, one per folded sample (empty when folding is
        disabled or no condition was supplied).
    """
    ckeys = CKeys if CKeys is not None else {}
    debug_level = ckeys.get('Debug_TrainerPack')

    # `is None` checks: `!= None` is elementwise on arrays and breaks `if`.
    lenn_val = 0
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val = len(X)
    if X_string is not None:
        lenn_val = len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...")
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditingig x_data ...")
        lenn_val = len(x_data)

    print ('Device: ', model.device)

    pdb_list = []
    for iisample in range(lenn_val):
        print(f"Working on {iisample}")

        # ---------------- text conditioning ----------------
        X_cond = None
        if X_string is None and X is not None:
            X_cond = torch.Tensor(X[iisample]).to(device).unsqueeze(0)
        if X_string is not None:
            # from raw text: needs tokenizer_X_forTextCondi and its normfac
            XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample])
            XX = sequence.pad_sequences(XX, maxlen=max_text_len, padding='post', truncating='post')
            XX = np.array(XX)
            X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi
            print ('Tokenized and processed: ', X_cond)

        if X_cond is not None:
            if normalize_X_cond_to_one:
                # used when there is a constraint on X_cond.sum()
                X_cond = X_cond/X_cond.sum()
            print ("Text conditoning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales)
        else:
            print ("Text conditioning used: None")

        # ---------------- image conditioning ----------------
        # text and image conditioning may be used at the same time
        if tokenizer_X_forImageCondi is None:
            # numeric condition: normalization only, passed pre-tokenized
            raw_x_data = None
            if x_data is not None:
                x_data_tokenized = torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi)
                x_data_tokenized = x_data_tokenized.to(torch.float)
                if debug_level == 1:
                    print("x_data_tokenized dim: ", x_data_tokenized.shape)
                    print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                    print("test: ", x_data_tokenized is not None)
            else:
                x_data_tokenized = None
                if debug_level == 1:
                    print("x_data_tokenized and x_data: None")
        else:
            # raw condition: model.sample tokenizes it itself
            raw_x_data = x_data[iisample] if x_data is not None else None
            x_data_tokenized = None

        result = model.sample(
            x=X_cond,
            stop_at_unet_number=train_unet_number,
            cond_scale=cond_scales,
            x_data=raw_x_data,
            x_data_tokenized=x_data_tokenized,
            skip_steps=skip_steps,
            inpaint_images=inpaint_images,
            inpaint_masks=inpaint_masks,
            inpaint_resample_times=inpaint_resample_times,
            init_images=init_images,
            device=model.device,
            tokenizer_X=tokenizer_X_forImageCondi,
            Xnormfac=Xnormfac_forImageCondi,
            ynormfac=ynormfac,
            max_length=max_length_Y,
            max_text_len=max_text_len,
        )
        result = torch.round(result*ynormfac)
        print("result.dim: ", result.shape)

        # ---------------- plot the raw prediction ----------------
        fig = plt.figure()
        plt.plot(result[0,0,:].cpu().detach().numpy(), label='Predicted')
        plt.legend()
        if IF_showfig==1:
            plt.show()
        else:
            outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
            plt.savefig(outname, dpi=200)
        plt.close(fig)

        # ---------------- decode the output tokens ----------------
        to_rev = result[:,0,:].long().cpu().detach().numpy()
        print("to_rev.dim: ", to_rev.shape)
        y_data_reversed = [
            text.upper().strip().replace(" ", "")
            for text in tokenizer_y.sequences_to_texts(to_rev)
        ]

        # ---------------- recover the conditions for the summary ----------
        if X_cond is None:
            X_text_reversed = None
        elif X_string is not None:
            X_cond = torch.round(X_cond*Xnormfac_forTextCondi)
            to_rev = X_cond[:,:].long().cpu().detach().numpy()
            print ("to_rev.dim: ", to_rev.shape)
            # fixed: the original looped over the undefined `y_text_reversed`
            X_text_reversed = [
                text.upper().strip().replace(" ", "")
                for text in tokenizer_X_forTextCondi.sequences_to_texts(to_rev)
            ]
        else:
            X_text_reversed = X_cond

        x_data_reversed = x_data  # already in sequence format (or None)

        # X may be None or shorter than the sample count when another
        # condition source drives the loop.
        X_shown = X[iisample] if X is not None and iisample < len(X) else None
        print (f"For {X_text_reversed} or {X_shown} on Text_Condi,\n and {x_data_reversed} on Image_Condi,")
        print (f"predicted sequence full: {y_data_reversed}")
        print (f"predicted sequence: {y_data_reversed[0]}")

        print("================================================")
        print("foldproteins: ", foldproteins)
        if not foldproteins:
            continue

        # ---------------- fold and store the structure ----------------
        pdb_file = foldandsavePDB(
            sequence=y_data_reversed[0],
            filename_out='temp',
            num_cycle=num_cycle,
            flag=flag,
            prefix=sample_dir,
        )
        out_nam_fasta = f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
        write_fasta(y_data_reversed[0], out_nam_fasta)

        out_nam = f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
        shutil.copy(pdb_file, out_nam)  # source, dest
        pdb_file = out_nam
        print (f"Properly named PDB file produced: {pdb_file}")

        if IF_showfig==1:
            view = show_pdb(
                pdb_file=pdb_file,
                flag=flag,
                show_sidechains=show_sidechains,
                show_mainchains=show_mainchains,
                color=color,
            )
            view.show()

        if calc_error:
            # only works for ModelA:SecStr
            if ckeys.get('Problem_ID')==7:
                get_Model_A_error(pdb_file, X[iisample], plotit=True)
            else:
                print ("Error calculation on the predicted results is not applicable")

        pdb_list.append(pdb_file)

    return pdb_list
# +++
def sample_sequence_omegafold_pLM_ModelA (
    model,
    X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]],  # from text conditioning X
    flag=0,
    cond_scales=1.,
    foldproteins=False,
    X_string=None,  # from text conditioning X_string
    x_data=None,  # from image conditioning x_data
    skip_steps=0,
    inpaint_images=None,  # in formation Y data
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,  # for omegafolding
    calc_error=False,  # for check on folded results, not used for every case
    # ++++++++++++++++++++++++++
    # tokenizers
    tokenizer_X_forImageCondi=None,  # for x_data
    Xnormfac_forImageCondi=1.,
    tokenizer_X_forTextCondi=None,  # for X if NEEDED only
    Xnormfac_forTextCondi=1.,
    tokenizer_y=None,  # for output Y
    ynormfac=1,
    # length
    train_unet_number=1,
    max_length_Y=1,  # for Y, X_forImageCondi
    max_text_len=1,  # for X_forTextCondi
    # other info
    steps=None,
    e=None,
    sample_dir=None,
    prefix=None,
    IF_showfig=True,
    CKeys=None,
    # TBA to Model B
    normalize_X_cond_to_one=False,
    # ++
    pLM_Model=None,  # pLM_Model,
    pLM_Model_Name=None,  # pLM_Model_Name,
    image_channels=None,  # image_channels,
    pLM_alphabet=None,  # esm_alphabet,
):
    """Draw one pLM-embedding sample per condition and optionally fold it.

    Same flow as the non-pLM sampler, except that ``model.sample`` returns an
    embedding which is mapped back to tokens with ``convert_into_tokens`` and
    decoded with ``decode_many_ems_token_rec_for_folding`` (per the original
    comments: embedding is ``[batch, channels, seq_len]``, tokens are
    ``[batch, seq_len]``).

    The number of samples is ``len()`` of the condition source; when several
    sources are given the later one wins (``X`` < ``X_string`` < ``x_data``).

    Per sample the function
      1. builds the text condition ``X_cond`` from ``X[i]`` or from
         ``X_string[i]`` (tokenized, padded to ``max_text_len``, divided by
         ``Xnormfac_forTextCondi``),
      2. builds the image condition (pre-normalized ``x_data[i]`` when no
         image tokenizer is given, raw ``x_data[i]`` otherwise),
      3. samples an embedding, converts it to tokens and plots the first row,
      4. decodes the tokens to sequences,
      5. if ``foldproteins``: folds the first sequence with
         ``foldandsavePDB_pdb_fasta``, copies the PDB/FASTA results to
         ``{sample_dir}DeNovoSampling_{i}_epo_{e}_step_{steps}.pdb/.fasta``
         and removes the temporary files so a later sample cannot pick up a
         stale result.

    NOTE(review): ``sequence``, ``foldandsavePDB_pdb_fasta``, ``show_pdb``,
    ``show_sidechains``, ``show_mainchains`` and ``color`` are resolved from
    module scope -- confirm they are defined in this file.

    Args:
        model: object exposing ``device`` and ``sample(...)``.
        X: sequence of numeric condition vectors, or None.
        flag: forwarded to the folding helper and ``show_pdb``.
        cond_scales: forwarded to ``model.sample`` as ``cond_scale``.
        foldproteins: fold and save a structure for every sample.
        X_string: sequence of raw text conditions, or None.
        x_data: sequence of image/sequence conditions, or None.
        skip_steps, inpaint_images, inpaint_masks, inpaint_resample_times,
        init_images: forwarded unchanged to ``model.sample``.
        num_cycle: forwarded to the folding helper.
        calc_error: run ``get_Model_A_error`` when ``CKeys['Problem_ID']``
            is 7.
        tokenizer_X_forImageCondi, Xnormfac_forImageCondi: tokenizer and
            normalization factor for ``x_data``.
        tokenizer_X_forTextCondi, Xnormfac_forTextCondi: tokenizer and
            normalization factor for ``X_string``.
        tokenizer_y: unused here (decoding goes through the pLM).
        ynormfac: forwarded to ``model.sample``.
        train_unet_number: forwarded as ``stop_at_unet_number``.
        max_length_Y, max_text_len: forwarded to ``model.sample``;
            ``max_text_len`` is also the padding length for ``X_string``.
        steps, e: only used in output file names.
        sample_dir: string prefix for every file written.
        prefix: unused, kept for interface compatibility.
        IF_showfig: show figures/structures instead of saving the figure.
        CKeys: dict of control keys (``Debug_TrainerPack``, ``Problem_ID``);
            ``None`` is treated as empty.
        normalize_X_cond_to_one: divide ``X_cond`` by its sum.
        pLM_Model, pLM_Model_Name: protein language model and its name,
            forwarded to ``convert_into_tokens`` / the decoder.
        image_channels: unused here, kept for interface compatibility.
        pLM_alphabet: alphabet forwarded to the decoder.

    Returns:
        ``(pdb_list, fasta_list)``: paths of the files written for every
        folded sample (both empty when folding is disabled or no condition
        was supplied).
    """
    ckeys = CKeys if CKeys is not None else {}
    debug_level = ckeys.get('Debug_TrainerPack')

    # `is None` checks: `!= None` is elementwise on arrays and breaks `if`.
    lenn_val = 0
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val = len(X)
    if X_string is not None:
        lenn_val = len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...")
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditingig x_data ...")
        lenn_val = len(x_data)

    print ('Device: ', model.device)

    pdb_list = []
    fasta_list = []
    for iisample in range(lenn_val):
        print(f"Working on {iisample}")

        # ---------------- text conditioning ----------------
        X_cond = None
        if X_string is None and X is not None:
            X_cond = torch.Tensor(X[iisample]).to(device).unsqueeze(0)
        if X_string is not None:
            # from raw text: needs tokenizer_X_forTextCondi and its normfac
            XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample])
            XX = sequence.pad_sequences(XX, maxlen=max_text_len, padding='post', truncating='post')
            XX = np.array(XX)
            X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi
            print ('Tokenized and processed: ', X_cond)

        if X_cond is not None:
            if normalize_X_cond_to_one:
                # used when there is a constraint on X_cond.sum()
                X_cond = X_cond/X_cond.sum()
            print ("Text conditoning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales)
        else:
            print ("Text conditioning used: None")

        # ---------------- image conditioning ----------------
        # text and image conditioning may be used at the same time
        if tokenizer_X_forImageCondi is None:
            # numeric condition: normalization only, passed pre-tokenized
            raw_x_data = None
            if x_data is not None:
                x_data_tokenized = torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi)
                x_data_tokenized = x_data_tokenized.to(torch.float)
                if debug_level == 1:
                    print("x_data_tokenized dim: ", x_data_tokenized.shape)
                    print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                    print("test: ", x_data_tokenized is not None)
            else:
                x_data_tokenized = None
                if debug_level == 1:
                    print("x_data_tokenized and x_data: None")
        else:
            # Model B style input: raw condition plus its tokenizer
            raw_x_data = x_data[iisample] if x_data is not None else None
            x_data_tokenized = None

        result_embedding = model.sample(
            x=X_cond,
            stop_at_unet_number=train_unet_number,
            cond_scale=cond_scales,
            x_data=raw_x_data,
            x_data_tokenized=x_data_tokenized,
            skip_steps=skip_steps,
            inpaint_images=inpaint_images,
            inpaint_masks=inpaint_masks,
            inpaint_resample_times=inpaint_resample_times,
            init_images=init_images,
            device=model.device,
            tokenizer_X=tokenizer_X_forImageCondi,
            Xnormfac=Xnormfac_forImageCondi,
            ynormfac=ynormfac,
            max_length=max_length_Y,
            max_text_len=max_text_len,
        )

        # ---------------- embedding -> tokens ----------------
        result_tokens, result_logits = convert_into_tokens(
            pLM_Model,
            result_embedding,
            pLM_Model_Name,
        )
        result = result_tokens.unsqueeze(1)  # add a channel axis for plotting
        print("result.dim: ", result.shape)

        # ---------------- plot the raw prediction ----------------
        fig = plt.figure()
        plt.plot(result[0,0,:].cpu().detach().numpy(), label='Predicted')
        plt.legend()
        if IF_showfig==1:
            plt.show()
        else:
            outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
            plt.savefig(outname, dpi=200)
        plt.close(fig)

        # ---------------- decode the output tokens ----------------
        # the decoder decides where each sequence ends
        y_data_reversed = decode_many_ems_token_rec_for_folding(
            result[:,0,:],
            result_logits,
            pLM_alphabet,
            pLM_Model,
        )
        if debug_level == 3:
            print("on y_data_reversed[0]: ", y_data_reversed[0])

        # ---------------- recover the conditions for the summary ----------
        if X_cond is None:
            X_text_reversed = None
        elif X_string is not None:
            X_cond = torch.round(X_cond*Xnormfac_forTextCondi)
            to_rev = X_cond[:,:].long().cpu().detach().numpy()
            print ("to_rev.dim: ", to_rev.shape)
            # fixed: the original looped over the undefined `y_text_reversed`
            X_text_reversed = [
                text.upper().strip().replace(" ", "")
                for text in tokenizer_X_forTextCondi.sequences_to_texts(to_rev)
            ]
        else:
            X_text_reversed = X_cond

        x_data_reversed = x_data  # already in sequence format (or None)

        # X may be None or shorter than the sample count when another
        # condition source drives the loop.
        X_shown = X[iisample] if X is not None and iisample < len(X) else None
        print (f"For {X_text_reversed} or {X_shown} on Text_Condi,\n and {x_data_reversed} on Image_Condi,\n predicted sequence: ", y_data_reversed)

        print("================================================")
        print("foldproteins: ", foldproteins)
        if not foldproteins:
            continue

        # ---------------- fold and store the structure ----------------
        pdb_file, fasta_file = foldandsavePDB_pdb_fasta(
            sequence=y_data_reversed[0],
            filename_out='temp',
            num_cycle=num_cycle,
            flag=flag,
            prefix=sample_dir,
        )
        out_nam = f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
        out_nam_fasta = f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
        shutil.copy(pdb_file, out_nam)  # source, dest
        shutil.copy(fasta_file, out_nam_fasta)
        # clean the slate to avoid mistakenly reusing the previous files
        os.remove(pdb_file)
        os.remove(fasta_file)

        pdb_file = out_nam
        fasta_file = out_nam_fasta
        pdb_list.append(pdb_file)
        fasta_list.append(fasta_file)
        print (f"Properly named PDB file produced: {pdb_file}")

        if IF_showfig==1:
            view = show_pdb(
                pdb_file=pdb_file,
                flag=flag,
                show_sidechains=show_sidechains,
                show_mainchains=show_mainchains,
                color=color,
            )
            view.show()

        if calc_error:
            # only works for ModelA:SecStr
            if ckeys.get('Problem_ID')==7:
                get_Model_A_error(pdb_file, X[iisample], plotit=True)
            else:
                print("Error calculation on the predicted results is not applicable...")

    return pdb_list, fasta_list
def sample_loop_omegafold_ModelA (
    model,
    train_loader,
    cond_scales=None, # [7.5], #list of cond scales - each sampled...
    num_samples=None, # 2, #how many samples produced every time tested.....
    timesteps=None, # 100, # not used
    flag=None, # 0,
    foldproteins=False,
    #
    cond_image=False, # use_text_embedd=True,
    cond_text=True,
    skip_steps=0,
    #
    max_text_len=None,
    max_length=None,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=None,
    prefix=None,
    tokenizer_y=None,
    Xnormfac_CondiText=1,
    tokenizer_X_CondiText=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only after foldproteins=True
):
    """Sample from ``model`` for each mini-batch, plot, decode and optionally fold.

    For every mini-batch of ``train_loader`` and every entry of
    ``cond_scales``, the first ``num_samples`` conditioning rows are passed
    to ``model.sample``. The sampled output and the ground truth are
    multiplied by ``ynormfac``, rounded and decoded with ``tokenizer_y``.
    One prediction-vs-ground-truth figure is produced per sample, and when
    ``foldproteins`` is true the decoded sequence is folded through
    ``foldandsavePDB`` and the resulting PDB is copied into ``sample_dir``.

    Total number of samples = num_samples * (# of mini-batches) * len(cond_scales).

    Args:
        model: object exposing ``sample(...)`` and a ``device`` attribute.
        train_loader: iterable of ``(X, y)`` batches; only ``item[0]`` and
            ``item[1]`` are read.
        cond_scales: sequence of conditioning scales, one sampling pass each.
        num_samples: samples taken from the front of each mini-batch; capped
            at the batch size. ``None`` means "use the whole mini-batch".
        timesteps, cond_image, cond_text: accepted for interface
            compatibility; not used by this function.
        flag: forwarded to ``foldandsavePDB`` and ``show_pdb``.
        foldproteins: fold each decoded sequence when true.
        skip_steps, max_text_len, max_length: forwarded to ``model.sample``.
        train_unet_number: forwarded as ``stop_at_unet_number``.
        ynormfac: factor applied to the sampled output and the ground truth
            before rounding to token ids.
        prefix: forwarded to ``foldandsavePDB``.
        tokenizer_y: object with ``sequences_to_texts`` used to decode
            predictions and ground truth.
        Xnormfac_CondiText: scalar or sequence multiplied onto the
            conditioning input before rounding.
        tokenizer_X_CondiText: optional tokenizer for the conditioning
            input; when ``None`` the rounded integer array is kept as is.
        CKeys: mapping; ``CKeys['Debug_TrainerPack'] == 1`` enables extra
            prints.
        sample_dir: string prefix for every file written here.
        steps, e: only interpolated into output file names.
        IF_showfig: when truthy, figures/structures are shown instead of the
            figure being saved to disk.

    Returns:
        None. Results are printed and written to ``sample_dir``.
    """
    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    for idx, item in enumerate(train_loader):

        X_train_batch = item[0].to(device)
        y_train_batch = item[1].to(device)

        GT = y_train_batch.cpu().detach().unsqueeze(1)

        # Per-batch sample count. Kept in a local so that one short batch
        # does not permanently shrink the count for the batches after it.
        batch_len = y_train_batch.shape[0]
        if num_samples is None:
            n_pick = batch_len
        else:
            if num_samples > batch_len:
                print("Warning: sampling # > len(mini_batch)")
            n_pick = min(num_samples, batch_len)
        print (f"Producing {n_pick} samples...")

        X_train_batch_picked = X_train_batch[:n_pick, :]
        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)

        # loop over cond_scales:list
        for cond_scale in cond_scales:
            # ++ for model A
            result = model.sample (
                x=X_train_batch_picked,
                stop_at_unet_number=train_unet_number,
                cond_scale=cond_scale,
                #
                skip_steps=skip_steps,
                device=model.device,
                #
                max_length=max_length,
                max_text_len=max_text_len,
            )

            result = torch.round(result*ynormfac)
            # Scale into a separate tensor: re-assigning GT here would apply
            # ynormfac again on every further cond scale.
            GT_scaled = torch.round(GT*ynormfac)

            # ---- decode once per sampling pass (independent of the sample index)
            # reverse y sequence
            pred_ids = result[:, 0, :].long().cpu().detach().numpy()
            y_data_reversed = [
                txt.upper().strip().replace(" ", "")
                for txt in tokenizer_y.sequences_to_texts (pred_ids)
            ]

            # reverse GT_y sequence
            gt_ids = GT_scaled[:, 0, :].long().cpu().detach().numpy()
            GT_y_data_reversed = [
                txt.upper().strip().replace(" ", "")
                for txt in tokenizer_y.sequences_to_texts (gt_ids)
            ]

            ### reverse second structure input....
            # pay attention to the shape of Xnormfac:
            # as_tensor keeps a scalar as the value itself, whereas
            # torch.FloatTensor(1) would allocate an uninitialised tensor.
            x_scale = torch.as_tensor(
                Xnormfac_CondiText,
                dtype=torch.float32,
                device=X_train_batch.device,
            )
            to_rev = torch.round (X_train_batch*x_scale)
            to_rev = to_rev.long().cpu().detach().numpy()

            # ++ different input
            if CKeys['Debug_TrainerPack']==1:
                print("tokenizer_X_CondiText: ", tokenizer_X_CondiText)
                print("Xnormfac_CondiText: ", Xnormfac_CondiText)

            if tokenizer_X_CondiText is not None:
                X_data_reversed = [
                    txt.upper().strip().replace(" ", "")
                    for txt in tokenizer_X_CondiText.sequences_to_texts (to_rev)
                ]
            else:
                X_data_reversed = to_rev.copy()

            # + for debug
            if CKeys['Debug_TrainerPack']==1:
                print("X_data_reversed: ", X_data_reversed)

            for samples in range (n_pick):
                print ("sample ", samples+1, "out of ", n_pick)

                plt.figure()
                plt.plot (
                    result[samples,0,:].cpu().detach().numpy(),
                    label= f'Predicted'
                )
                plt.plot (
                    GT_scaled[samples,0,:],
                    label= f'GT {0}'
                )
                plt.legend()
                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scale)}_{e}_{steps}.jpg"
                if IF_showfig==1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                plt.close ()

                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples])
                print (f"Ground truth: {GT_y_data_reversed[samples]}")

                if not foldproteins:
                    continue

                tempname='temp'
                pdb_file=foldandsavePDB (
                    sequence=y_data_reversed[samples],
                    filename_out=tempname,
                    num_cycle=16,
                    flag=flag,
                    # +++++++++++++++++++
                    prefix=prefix
                )

                # Names are built from loop indices because a name derived
                # from the conditioning input can get too long to fit.
                out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scale)}_epo_{e}_step_{steps}.pdb'
                out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scale)}_epo_{e}_step_{steps}.txt'

                if CKeys['Debug_TrainerPack']==1:
                    print("pdb_file: ", pdb_file)
                    print("out_nam: ", out_nam)

                print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                shutil.copy (pdb_file, out_nam) #source, dest
                # + keep the conditioning input next to the structure
                with open(out_nam_inX, "w") as inX_file:
                    inX_file.write(f'{X_data_reversed[samples]}\n')

                pdb_file=out_nam
                print (f"Properly named PDB file produced: {pdb_file}")
                print (f"input X for sampling stored: {out_nam_inX}")

                if IF_showfig==1:
                    # NOTE(review): show_sidechains, show_mainchains and color
                    # are not parameters of this function; they have to
                    # resolve at module level -- confirm they are defined.
                    view=show_pdb(
                        pdb_file=pdb_file,
                        flag=flag,
                        show_sidechains=show_sidechains,
                        show_mainchains=show_mainchains,
                        color=color
                    )
                    view.show()
def sample_loop_omegafold_pLM_ModelA (
    model,
    train_loader,
    cond_scales=None, # [7.5], #list of cond scales - each sampled...
    num_samples=None, # 2, #how many samples produced every time tested.....
    timesteps=None, # 100, # not used
    flag=None, # 0,
    foldproteins=False,
    #
    cond_image=False, # use_text_embedd=True,
    cond_text=True,
    skip_steps=0,
    #
    max_text_len=None,
    max_length=None,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=None,
    prefix=None,
    tokenizer_y=None,
    Xnormfac_CondiText=1,
    tokenizer_X_CondiText=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only after foldproteins=True
    # ++ for pLM
    pLM_Model=None,
    pLM_Model_Name=None,
    image_channels=None,
    pLM_alphabet=None,
    # ++ for on-fly check: for SecStr only
    calc_error=False, # for check on folded results, not used for every case
):
    """Sample pLM embeddings from ``model``, decode them and optionally fold.

    For every mini-batch of ``train_loader`` and every entry of
    ``cond_scales``, the first ``num_samples`` conditioning rows are passed
    to ``model.sample``. The result is turned into tokens and logits by
    ``convert_into_tokens`` and decoded to sequences with
    ``decode_many_ems_token_rec_for_folding``; the ground truth is decoded
    with ``decode_many_ems_token_rec``. One prediction-vs-ground-truth
    figure is produced per sample. When ``foldproteins`` is true, every
    non-empty predicted sequence is folded with ``foldandsavePDB_pdb_fasta``
    and the PDB, the FASTA and the conditioning input are stored under
    ``sample_dir``; the temporary PDB/FASTA files are removed afterwards.

    Total number of samples = num_samples * (# of mini-batches) * len(cond_scales).

    Args:
        model: object exposing ``sample(...)`` and a ``device`` attribute.
        train_loader: iterable of ``(X, y)`` batches; only ``item[0]`` and
            ``item[1]`` are read.
        cond_scales: sequence of conditioning scales, one sampling pass each.
        num_samples: samples taken from the front of each mini-batch; capped
            at the batch size. ``None`` means "use the whole mini-batch".
        timesteps, cond_image, cond_text, ynormfac, tokenizer_y,
        image_channels: accepted for interface compatibility; not used by
            this function.
        flag: forwarded to ``foldandsavePDB_pdb_fasta`` and ``show_pdb``.
        foldproteins: fold each decoded sequence when true.
        skip_steps, max_text_len, max_length: forwarded to ``model.sample``.
        train_unet_number: forwarded as ``stop_at_unet_number``.
        prefix: forwarded to ``foldandsavePDB_pdb_fasta``.
        Xnormfac_CondiText: scalar or sequence multiplied onto the
            conditioning input; also forwarded to ``model.sample``.
        tokenizer_X_CondiText: optional tokenizer for the conditioning
            input; also forwarded to ``model.sample``. When ``None`` the
            scaled float array is kept as is.
        CKeys: mapping; ``'Debug_TrainerPack'`` (values 1 and 3) enables
            extra prints and ``'Problem_ID'`` gates the error check.
        sample_dir: string prefix for every file written here.
        steps, e: only interpolated into output file names.
        IF_showfig: when truthy, figures/structures are shown instead of the
            figure being saved to disk.
        pLM_Model, pLM_Model_Name, pLM_alphabet: forwarded to the token
            conversion / decoding helpers.
        calc_error: when true and ``CKeys['Problem_ID'] == 7``, run
            ``get_Model_A_error`` on each folded structure.

    Returns:
        None. Results are printed and written to ``sample_dir``.
    """
    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    for idx, item in enumerate(train_loader):

        X_train_batch = item[0].to(device)
        y_train_batch = item[1].to(device)

        GT = y_train_batch.cpu().detach().unsqueeze(1)

        # Per-batch sample count. Kept in a local so that one short batch
        # does not permanently shrink the count for the batches after it.
        batch_len = y_train_batch.shape[0]
        if num_samples is None:
            n_pick = batch_len
        else:
            if num_samples > batch_len:
                print("Warning: sampling # > len(mini_batch)")
            n_pick = min(num_samples, batch_len)
        print (f"Producing {n_pick} samples...")

        X_train_batch_picked = X_train_batch[:n_pick, :]
        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)

        # loop over cond_scales:list
        for cond_scale in cond_scales:
            # ++ for model A
            result_embedding = model.sample (
                x=X_train_batch_picked,
                stop_at_unet_number=train_unet_number,
                cond_scale=cond_scale,
                #
                skip_steps=skip_steps,
                device=model.device,
                #
                max_length=max_length,
                max_text_len=max_text_len,
                #
                x_data=None,
                x_data_tokenized=None,
                #
                tokenizer_X=tokenizer_X_CondiText,
                Xnormfac=Xnormfac_CondiText,
            )

            # ++ for pLM:
            # result_embedding as image.dim: [batch, channels, seq_len]
            # result_tokens.dim: [batch, seq_len]
            result_tokens, result_logits = convert_into_tokens(
                pLM_Model,
                result_embedding,
                pLM_Model_Name,
            )
            result = result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]

            # reverse y sequence
            y_data_reversed = decode_many_ems_token_rec_for_folding(
                result[:, 0, :], # token (batch,seq_len)
                result_logits,
                pLM_alphabet,
                pLM_Model,
            )
            if CKeys['Debug_TrainerPack']==3:
                print("on y_data_reversed[0]: ", y_data_reversed[0])

            # reverse GT_y sequence
            # GT should be SAFE to reverse
            GT_y_data_reversed = decode_many_ems_token_rec(
                GT[:, 0, :],
                pLM_alphabet,
            )

            ### reverse second structure input....
            # pay attention to the shape of Xnormfac; the scale is placed on
            # the batch's own device so the product cannot mix devices.
            x_scale = torch.as_tensor(
                Xnormfac_CondiText, device=X_train_batch.device
            )
            # +: just float (no conversion into int64)
            to_rev = (X_train_batch*x_scale).cpu().detach().numpy()

            # ++ different input
            if CKeys['Debug_TrainerPack']==1:
                print("tokenizer_X_CondiText: ", tokenizer_X_CondiText)
                print("Xnormfac_CondiText: ", Xnormfac_CondiText)

            if tokenizer_X_CondiText is not None:
                # round the number into tokens
                to_rev = np.round(to_rev)
                X_data_reversed = [
                    txt.upper().strip().replace(" ", "")
                    for txt in tokenizer_X_CondiText.sequences_to_texts (to_rev)
                ]
            else:
                X_data_reversed = to_rev.copy()

            # + for debug
            if CKeys['Debug_TrainerPack']==1:
                print("X_data_reversed: ", X_data_reversed)
                # np.shape works for both the ndarray and the list-of-text case
                print("X_data_reversed.dim: ", np.shape(X_data_reversed))

            for samples in range (n_pick):
                print ("sample ", samples+1, "out of ", n_pick)

                plt.figure()
                plt.plot (
                    result[samples,0,:].cpu().detach().numpy(),
                    label= f'Predicted'
                )
                plt.plot (
                    GT[samples,0,:],
                    label= f'GT {0}'
                )
                plt.legend()
                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scale)}_{e}_{steps}.jpg"
                if IF_showfig==1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                plt.close ()

                print (f"For input in dataloader: {X_train_batch[samples,:].cpu().detach().numpy()} or \n recovered input {X_data_reversed[samples]}")
                print (f"predicted sequence: {y_data_reversed[samples]}")
                print (f"Ground truth: {GT_y_data_reversed[samples]}")

                if not foldproteins:
                    continue

                # check whether the predicted sequence is valid
                if len(y_data_reversed[samples]) == 0:
                    print("The predicted sequence is EMPTY...")
                    continue

                # ++ text form of the conditioning input for this sample
                xbc = X_data_reversed[samples]
                if isinstance(xbc, np.ndarray):
                    out_nam_content = np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})
                else:
                    # Decoded text (tokenizer supplied): array2string is
                    # documented for ndarrays only and is expected to fail
                    # on a plain str, so write the text as is.
                    out_nam_content = str(xbc)

                tempname='temp'
                pdb_file, fasta_file = foldandsavePDB_pdb_fasta (
                    sequence=y_data_reversed[samples],
                    filename_out=tempname,
                    num_cycle=16,
                    flag=flag,
                    # +++++++++++++++++++
                    prefix=prefix
                )

                # Names are built from loop indices because a name derived
                # from the conditioning input can get too long to fit.
                out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scale)}_epo_{e}_step_{steps}.pdb'
                out_nam_seq=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scale)}_epo_{e}_step_{steps}.fasta'
                out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scale)}_epo_{e}_step_{steps}.txt'

                if CKeys['Debug_TrainerPack']==1:
                    print("pdb_file: ", pdb_file)
                    print("out_nam: ", out_nam)

                print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                shutil.copy (pdb_file, out_nam) #source, dest
                shutil.copy (fasta_file, out_nam_seq)
                # + keep the conditioning input next to the structure
                with open(out_nam_inX, "w") as inX_file:
                    inX_file.write(out_nam_content)
                # clean the slate to avoid mistakenly using the previous fasta file
                os.remove (pdb_file)
                os.remove (fasta_file)

                pdb_file=out_nam
                print (f"Properly named PDB file produced: {pdb_file}")
                print (f"input X for sampling stored: {out_nam_inX}")

                if IF_showfig==1:
                    # NOTE(review): show_sidechains, show_mainchains and color
                    # are not parameters of this function; they have to
                    # resolve at module level -- confirm they are defined.
                    view=show_pdb(
                        pdb_file=pdb_file,
                        flag=flag,
                        show_sidechains=show_sidechains,
                        show_mainchains=show_mainchains,
                        color=color
                    )
                    view.show()

                if calc_error:
                    print('On-fly check...')
                    if CKeys['Problem_ID']==7: # only work for ModelA:SecStr
                        get_Model_A_error (pdb_file, X_data_reversed[samples], plotit=True)
                    else:
                        print("Error calculation on the predicted results is not applicable...")