diff --git "a/PD_pLMProbXDiff/TrainerPack.py" "b/PD_pLMProbXDiff/TrainerPack.py" new file mode 100644--- /dev/null +++ "b/PD_pLMProbXDiff/TrainerPack.py" @@ -0,0 +1,8839 @@ +import os +import time +import copy +from pathlib import Path +from math import ceil +from contextlib import contextmanager, nullcontext +from functools import partial, wraps +from collections.abc import Iterable + +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.data import random_split, DataLoader +from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from torch.cuda.amp import autocast, GradScaler + +import pytorch_warmup as warmup + +import shutil + +import esm +from einops import rearrange + +from packaging import version +__version__ = '1.9.3' + +device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" +) + +import matplotlib.pyplot as plt + +def cycle(dl): + while True: + for data in dl: + yield data + +from packaging import version + +import numpy as np + +from ema_pytorch import EMA + +from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs + +from fsspec.core import url_to_fs +from fsspec.implementations.local import LocalFileSystem + +# # -- +# from PD_SpLMxDiff.ModelPack import resize_image_to,ProteinDesigner_B +# from PD_SpLMxDiff.UtilityPack import get_Model_A_error, convert_into_tokens +# from PD_SpLMxDiff.UtilityPack import decode_one_ems_token_rec,decode_many_ems_token_rec +# from PD_SpLMxDiff.UtilityPack import decode_one_ems_token_rec_for_folding,decode_many_ems_token_rec_for_folding + +# from PD_SpLMxDiff.UtilityPack import decode_one_ems_token_rec_for_folding_with_mask,decode_many_ems_token_rec_for_folding_with_mask,read_mask_from_input +# from PD_SpLMxDiff.UtilityPack import get_DSSP_result, string_diff +# from PD_SpLMxDiff.DataSetPack import pad_a_np_arr +# ++ +from PD_pLMProbXDiff.ModelPack import ( + resize_image_to, ProteinDesigner_B, +) +from PD_pLMProbXDiff.UtilityPack import ( + get_Model_A_error, convert_into_tokens,convert_into_tokens_using_prob, + decode_one_ems_token_rec, decode_many_ems_token_rec, + decode_one_ems_token_rec_for_folding, + decode_many_ems_token_rec_for_folding, + decode_one_ems_token_rec_for_folding_with_mask, + decode_many_ems_token_rec_for_folding_with_mask, + read_mask_from_input, + get_DSSP_result, + string_diff, + load_in_pLM, +) +from PD_pLMProbXDiff.DataSetPack import ( + pad_a_np_arr +) + +# loss function +criterion_MSE_sum = nn.MSELoss(reduction='sum') +criterion_MAE_sum = nn.L1Loss(reduction='sum') + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def cast_tuple(val, length = 1): + if isinstance(val, list): + val = tuple(val) + + return val if isinstance(val, tuple) else ((val,) * length) + +def find_first(fn, arr): + for ind, el in enumerate(arr): + if fn(el): + return ind + return -1 + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + +def group_dict_by_key(cond, d): + return_val = [dict(),dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + +def string_begins_with(prefix, str): + return str.startswith(prefix) + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, 
+    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+    return kwargs_without_prefix, kwargs
+
+# e.g. num_to_groups(25, 10) -> [10, 10, 5]
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+# url to fs, bucket, path - for checkpointing to cloud
+
+# e.g. url_to_bucket('gs://my-bucket/checkpoints') -> 'my-bucket'
+def url_to_bucket(url):
+    if '://' not in url:
+        return url
+
+    prefix, suffix = url.split('://')
+
+    if prefix in {'gs', 's3'}:
+        return suffix.split('/')[0]
+    else:
+        raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
+
+# decorators
+
+def eval_decorator(fn):
+    def inner(model, *args, **kwargs):
+        was_training = model.training
+        model.eval()
+        out = fn(model, *args, **kwargs)
+        model.train(was_training)
+        return out
+    return inner
+
+def cast_torch_tensor(fn, cast_fp16 = False):
+    @wraps(fn)
+    def inner(model, *args, **kwargs):
+        device = kwargs.pop('_device', model.device)
+        cast_device = kwargs.pop('_cast_device', True)
+
+        should_cast_fp16 = cast_fp16 and model.cast_half_at_training
+
+        kwargs_keys = kwargs.keys()
+        all_args = (*args, *kwargs.values())
+        split_kwargs_index = len(all_args) - len(kwargs_keys)
+        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
+
+        if cast_device:
+            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+
+        if should_cast_fp16:
+            all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))
+
+        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
+        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
+
+        out = fn(model, *args, **kwargs)
+        return out
+    return inner
+
+# gradient accumulation functions
+
+def split_iterable(it, split_size):
+    accum = []
+    for ind in range(ceil(len(it) / split_size)):
+        start_index = ind * split_size
+        accum.append(it[start_index: (start_index + split_size)])
+    return accum
+
+def split(t, split_size = None):
+    if not exists(split_size):
+        return t
+
+    if isinstance(t, torch.Tensor):
+        return t.split(split_size, dim = 0)
+
+    if isinstance(t, Iterable):
+        return split_iterable(t, split_size)
+
+    raise TypeError(f'cannot split object of type {type(t)}')
+
+# returns the first element of arr satisfying cond (or None)
+def find_first(cond, arr):
+    for el in arr:
+        if cond(el):
+            return el
+    return None
+
+def split_args_and_kwargs(*args, split_size = None, **kwargs):
+    all_args = (*args, *kwargs.values())
+    len_all_args = len(all_args)
+    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+    assert exists(first_tensor)
+
+    batch_size = len(first_tensor)
+    split_size = default(split_size, batch_size)
+    num_chunks = ceil(batch_size / split_size)
+
+    dict_len = len(kwargs)
+    dict_keys = kwargs.keys()
+    split_kwargs_index = len_all_args - dict_len
+
+    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
+    chunk_sizes = tuple(map(len, split_all_args[0]))
+
+    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
+        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+        chunk_size_frac = chunk_size / batch_size
+        yield chunk_size_frac, (chunked_args, chunked_kwargs)
+
+# imagen trainer
+
+def imagen_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        if self.imagen.unconditional:
+            batch_size = kwargs.get('batch_size')
+            batch_sizes = num_to_groups(batch_size, max_batch_size)
+            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+        else:
+            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+        if isinstance(outputs[0], torch.Tensor):
+            return torch.cat(outputs, dim = 0)
+
+        return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))
+
+    return inner
+
+def restore_parts(state_dict_target, state_dict_from):
+    for name, param in state_dict_from.items():
+
+        if name not in state_dict_target:
+            continue
+
+        if param.size() == state_dict_target[name].size():
+            state_dict_target[name].copy_(param)
+        else:
+            print(f"layer {name} ({param.size()}) is different than target ({state_dict_target[name].size()}), skipping")
+
+    return state_dict_target
+
+class ImagenTrainer(nn.Module):
+    locked = False
+
+    def __init__(
+        self,
+        #imagen = None,
+        model = None,
+
+        imagen_checkpoint_path = None,
+        use_ema = True,
+        lr = 1e-4,
+        eps = 1e-8,
+        beta1 = 0.9,
+        beta2 = 0.99,
+        max_grad_norm = None,
+        group_wd_params = True,
+        warmup_steps = None,
+        cosine_decay_max_steps = None,
+        only_train_unet_number = None,
+        fp16 = False,
+        precision = None,
+        split_batches = True,
+        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
+        verbose = True,
+        split_valid_fraction = 0.025,
+        split_valid_from_train = False,
+        split_random_seed = 42,
+        checkpoint_path = None,
+        checkpoint_every = None,
+        checkpoint_fs = None,
+        fs_kwargs: dict = None,
+        max_checkpoints_keep = 20,
+        # +++++++++++++++++++++
+        CKeys=None,
+        #
+        **kwargs
+    ):
+        super().__init__()
+        assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
+        assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either an imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'
+
+        # determine filesystem, using fsspec, for saving to local filesystem or cloud
+
+        self.fs = checkpoint_fs
+
+        if not exists(self.fs):
+            fs_kwargs = default(fs_kwargs, {})
+            self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
+
+        # # -----------------------------------
+        # # from MJB
+        # assert isinstance(model.imagen, (ProteinDesigner_B))
+        # modified by BN
+        # ++: try this trainer for all models
+        # assert isinstance(model, (ProteinDesigner_B))
+
+        # +++++++++++++++++++++++++
+        self.CKeys = CKeys
+
+        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+
+        self.imagen = model.imagen
+        self.model = model
+        self.is_elucidated = self.model.is_elucidated
+
+        # create accelerator instance
+
+        accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)
+
+        assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
+        accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')
+
+        self.accelerator = Accelerator(**{
+            'split_batches': split_batches,
+            'mixed_precision': accelerator_mixed_precision,
'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] + , **accelerate_kwargs}) + + ImagenTrainer.locked = self.is_distributed + + # cast data to fp16 at training time if needed + + self.cast_half_at_training = accelerator_mixed_precision == 'fp16' + + # grad scaler must be managed outside of accelerator + + grad_scaler_enabled = fp16 + + self.num_unets = len(self.imagen.unets) + + self.use_ema = use_ema and self.is_main + self.ema_unets = nn.ModuleList([]) + + # keep track of what unet is being trained on + # only going to allow 1 unet training at a time + + self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on + + # data related functions + + self.train_dl_iter = None + self.train_dl = None + + self.valid_dl_iter = None + self.valid_dl = None + + self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names + + # auto splitting validation from training, if dataset is passed in + + self.split_valid_from_train = split_valid_from_train + + assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' + self.split_valid_fraction = split_valid_fraction + self.split_random_seed = split_random_seed + + # be able to finely customize learning rate, weight decay + # per unet + + lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) + + for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): + optimizer = Adam( + unet.parameters(), + lr = unet_lr, + eps = unet_eps, + betas = (beta1, beta2), + **kwargs + ) + + if self.use_ema: + self.ema_unets.append(EMA(unet, **ema_kwargs)) + + scaler = GradScaler(enabled = grad_scaler_enabled) + + scheduler = warmup_scheduler = None + + if exists(unet_cosine_decay_max_steps): + scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) + + if exists(unet_warmup_steps): + warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) + + if not exists(scheduler): + scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) + + # set on object + + setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers + setattr(self, f'scaler{ind}', scaler) + setattr(self, f'scheduler{ind}', scheduler) + setattr(self, f'warmup{ind}', warmup_scheduler) + + # gradient clipping if needed + + self.max_grad_norm = max_grad_norm + + # step tracker and misc + + self.register_buffer('steps', torch.tensor([0] * self.num_unets)) + + self.verbose = verbose + + # automatic set devices based on what accelerator decided + + self.imagen.to(self.device) + self.to(self.device) + + # checkpointing + + assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) + self.checkpoint_path = checkpoint_path + self.checkpoint_every = checkpoint_every + self.max_checkpoints_keep = max_checkpoints_keep + + self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main + + if exists(checkpoint_path) and self.can_checkpoint: + bucket = url_to_bucket(checkpoint_path) + + if not self.fs.exists(bucket): + self.fs.mkdir(bucket) + + self.load_from_checkpoint_folder() + + # only allowing training for unet + + self.only_train_unet_number = only_train_unet_number + self.validate_and_set_unet_being_trained(only_train_unet_number) + + # computed values + + @property + def device(self): + 
return self.accelerator.device + + @property + def is_distributed(self): + return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) + + @property + def is_main(self): + return self.accelerator.is_main_process + + @property + def is_local_main(self): + return self.accelerator.is_local_main_process + + @property + def unwrapped_unet(self): + return self.accelerator.unwrap_model(self.unet_being_trained) + + # optimizer helper functions + + def get_lr(self, unet_number): + self.validate_unet_number(unet_number) + unet_index = unet_number - 1 + + optim = getattr(self, f'optim{unet_index}') + + return optim.param_groups[0]['lr'] + + # function for allowing only one unet from being trained at a time + + def validate_and_set_unet_being_trained(self, unet_number = None): + if exists(unet_number): + self.validate_unet_number(unet_number) + + assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' + + self.only_train_unet_number = unet_number + self.imagen.only_train_unet_number = unet_number + + if not exists(unet_number): + return + + self.wrap_unet(unet_number) + + def wrap_unet(self, unet_number): + if hasattr(self, 'one_unet_wrapped'): + return + + unet = self.imagen.get_unet(unet_number) + self.unet_being_trained = self.accelerator.prepare(unet) + unet_index = unet_number - 1 + + optimizer = getattr(self, f'optim{unet_index}') + scheduler = getattr(self, f'scheduler{unet_index}') + + optimizer = self.accelerator.prepare(optimizer) + + if exists(scheduler): + scheduler = self.accelerator.prepare(scheduler) + + setattr(self, f'optim{unet_index}', optimizer) + setattr(self, f'scheduler{unet_index}', scheduler) + + self.one_unet_wrapped = True + + # hacking accelerator due to not having separate gradscaler per optimizer + + def set_accelerator_scaler(self, unet_number): + unet_number = self.validate_unet_number(unet_number) + scaler = getattr(self, f'scaler{unet_number - 1}') + + self.accelerator.scaler = scaler + for optimizer in self.accelerator._optimizers: + optimizer.scaler = scaler + + # helper print + + def print(self, msg): + if not self.is_main: + return + + if not self.verbose: + return + + return self.accelerator.print(msg) + + # validating the unet number + + def validate_unet_number(self, unet_number = None): + if self.num_unets == 1: + unet_number = default(unet_number, 1) + + assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' + return unet_number + + # number of training steps taken + + def num_steps_taken(self, unet_number = None): + if self.num_unets == 1: + unet_number = default(unet_number, 1) + + return self.steps[unet_number - 1].item() + + def print_untrained_unets(self): + print_final_error = False + + for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): + if steps > 0 or isinstance(unet, NullUnet): + continue + + self.print(f'unet {ind + 1} has not been trained') + print_final_error = True + + if print_final_error: + self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') + + # data related functions + + def add_train_dataloader(self, dl = None): + if not exists(dl): + return + + assert not exists(self.train_dl), 'training dataloader was already added' + self.train_dl = 
self.accelerator.prepare(dl)
+
+    def add_valid_dataloader(self, dl):
+        if not exists(dl):
+            return
+
+        assert not exists(self.valid_dl), 'validation dataloader was already added'
+        self.valid_dl = self.accelerator.prepare(dl)
+
+    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
+        if not exists(ds):
+            return
+
+        assert not exists(self.train_dl), 'training dataloader was already added'
+
+        valid_ds = None
+        if self.split_valid_from_train:
+            train_size = int((1 - self.split_valid_fraction) * len(ds))
+            valid_size = len(ds) - train_size
+
+            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
+            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
+
+        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
+        self.train_dl = self.accelerator.prepare(dl)
+
+        if not self.split_valid_from_train:
+            return
+
+        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)
+
+    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
+        if not exists(ds):
+            return
+
+        assert not exists(self.valid_dl), 'validation dataloader was already added'
+
+        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
+        self.valid_dl = self.accelerator.prepare(dl)
+
+    def create_train_iter(self):
+        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'
+
+        if exists(self.train_dl_iter):
+            return
+
+        self.train_dl_iter = cycle(self.train_dl)
+
+    def create_valid_iter(self):
+        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'
+
+        if exists(self.valid_dl_iter):
+            return
+
+        self.valid_dl_iter = cycle(self.valid_dl)
+
+    def train_step(self, unet_number = None, **kwargs):
+        self.create_train_iter()
+        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
+        self.update(unet_number = unet_number)
+        return loss
+
+    @torch.no_grad()
+    @eval_decorator
+    def valid_step(self, **kwargs):
+        self.create_valid_iter()
+
+        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
+
+        with context():
+            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
+        return loss
+
+    def step_with_dl_iter(self, dl_iter, **kwargs):
+        dl_tuple_output = cast_tuple(next(dl_iter))
+        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
+        loss = self.forward(**{**kwargs, **model_input})
+        return loss
+
+    # checkpointing functions
+
+    @property
+    def all_checkpoints_sorted(self):
+        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
+        checkpoints = self.fs.glob(glob_pattern)
+        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
+        return sorted_checkpoints
+
+    def load_from_checkpoint_folder(self, last_total_steps = -1):
+        if last_total_steps != -1:
+            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
+            self.load(filepath)
+            return
+
+        sorted_checkpoints = self.all_checkpoints_sorted
+
+        if len(sorted_checkpoints) == 0:
+            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
+            return
+
+        last_checkpoint = sorted_checkpoints[0]
+        self.load(last_checkpoint)
+
+    def save_to_checkpoint_folder(self):
+        self.accelerator.wait_for_everyone()
+
+        if not self.can_checkpoint:
+            return
+
+        total_steps = int(self.steps.sum().item())
+        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')
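+        # illustrative naming: if self.steps sums to 1500, this writes 'checkpoint.1500.pt';
+        # anything beyond the max_checkpoints_keep newest files is pruned just below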
+ + self.save(filepath) + + if self.max_checkpoints_keep <= 0: + return + + sorted_checkpoints = self.all_checkpoints_sorted + checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] + + for checkpoint in checkpoints_to_discard: + self.fs.rm(checkpoint) + + # saving and loading functions + + def save( + self, + path, + overwrite = True, + without_optim_and_sched = False, + **kwargs + ): + self.accelerator.wait_for_everyone() + + if not self.can_checkpoint: + return + + fs = self.fs + + assert not (fs.exists(path) and not overwrite) + + self.reset_ema_unets_all_one_device() + + save_obj = dict( + model = self.imagen.state_dict(), + version = __version__, + steps = self.steps.cpu(), + **kwargs + ) + + save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() + + for ind in save_optim_and_sched_iter: + scaler_key = f'scaler{ind}' + optimizer_key = f'optim{ind}' + scheduler_key = f'scheduler{ind}' + warmup_scheduler_key = f'warmup{ind}' + + scaler = getattr(self, scaler_key) + optimizer = getattr(self, optimizer_key) + scheduler = getattr(self, scheduler_key) + warmup_scheduler = getattr(self, warmup_scheduler_key) + + if exists(scheduler): + save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} + + if exists(warmup_scheduler): + save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} + + save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} + + if self.use_ema: + save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} + + # determine if imagen config is available + + if hasattr(self.imagen, '_config'): + self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""') + + save_obj = { + **save_obj, + 'imagen_type': 'elucidated' if self.is_elucidated else 'original', + 'imagen_params': self.imagen._config + } + + #save to path + + with fs.open(path, 'wb') as f: + torch.save(save_obj, f) + + self.print(f'checkpoint saved to {path}') + + def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): + fs = self.fs + + if noop_if_not_exist and not fs.exists(path): + self.print(f'trainer checkpoint not found at {str(path)}') + return + + assert fs.exists(path), f'{path} does not exist' + + self.reset_ema_unets_all_one_device() + + # to avoid extra GPU memory usage in main process when using Accelerate + + with fs.open(path) as f: + loaded_obj = torch.load(f, map_location='cpu') + + if version.parse(__version__) != version.parse(loaded_obj['version']): + self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') + + try: + self.imagen.load_state_dict(loaded_obj['model'], strict = strict) + except RuntimeError: + print("Failed loading state dict. 
Trying partial load") + self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), + loaded_obj['model'])) + + if only_model: + return loaded_obj + + self.steps.copy_(loaded_obj['steps']) + + for ind in range(0, self.num_unets): + scaler_key = f'scaler{ind}' + optimizer_key = f'optim{ind}' + scheduler_key = f'scheduler{ind}' + warmup_scheduler_key = f'warmup{ind}' + + scaler = getattr(self, scaler_key) + optimizer = getattr(self, optimizer_key) + scheduler = getattr(self, scheduler_key) + warmup_scheduler = getattr(self, warmup_scheduler_key) + + if exists(scheduler) and scheduler_key in loaded_obj: + scheduler.load_state_dict(loaded_obj[scheduler_key]) + + if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: + warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) + + if exists(optimizer): + try: + optimizer.load_state_dict(loaded_obj[optimizer_key]) + scaler.load_state_dict(loaded_obj[scaler_key]) + except: + self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers') + + if self.use_ema: + assert 'ema' in loaded_obj + try: + self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) + except RuntimeError: + print("Failed loading state dict. Trying partial load") + self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), + loaded_obj['ema'])) + + self.print(f'checkpoint loaded from {path}') + return loaded_obj + + # managing ema unets and their devices + + @property + def unets(self): + return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) + + def get_ema_unet(self, unet_number = None): + if not self.use_ema: + return + + unet_number = self.validate_unet_number(unet_number) + index = unet_number - 1 + + if isinstance(self.unets, nn.ModuleList): + unets_list = [unet for unet in self.ema_unets] + delattr(self, 'ema_unets') + self.ema_unets = unets_list + + if index != self.ema_unet_being_trained_index: + for unet_index, unet in enumerate(self.ema_unets): + unet.to(self.device if unet_index == index else 'cpu') + + self.ema_unet_being_trained_index = index + return self.ema_unets[index] + + def reset_ema_unets_all_one_device(self, device = None): + if not self.use_ema: + return + + device = default(device, self.device) + self.ema_unets = nn.ModuleList([*self.ema_unets]) + self.ema_unets.to(device) + + self.ema_unet_being_trained_index = -1 + + @torch.no_grad() + @contextmanager + def use_ema_unets(self): + if not self.use_ema: + output = yield + return output + + self.reset_ema_unets_all_one_device() + self.imagen.reset_unets_all_one_device() + + self.unets.eval() + + trainable_unets = self.imagen.unets + self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling + + output = yield + + self.imagen.unets = trainable_unets # restore original training unets + + # cast the ema_model unets back to original device + for ema in self.ema_unets: + ema.restore_ema_model_device() + + return output + + def print_unet_devices(self): + self.print('unet devices:') + for i, unet in enumerate(self.imagen.unets): + device = next(unet.parameters()).device + self.print(f'\tunet {i}: {device}') + + if not self.use_ema: + return + + self.print('\nema unet devices:') + for i, ema_unet in enumerate(self.ema_unets): + device = next(ema_unet.parameters()).device + self.print(f'\tema unet {i}: {device}') + + # overriding state dict functions + + def state_dict(self, *args, **kwargs): + 
self.reset_ema_unets_all_one_device() + return super().state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + self.reset_ema_unets_all_one_device() + return super().load_state_dict(*args, **kwargs) + + # encoding text functions + + def encode_text(self, text, **kwargs): + return self.imagen.encode_text(text, **kwargs) + + # forwarding functions and gradient step updates + + def update(self, unet_number = None): + unet_number = self.validate_unet_number(unet_number) + self.validate_and_set_unet_being_trained(unet_number) + self.set_accelerator_scaler(unet_number) + + index = unet_number - 1 + unet = self.unet_being_trained + + optimizer = getattr(self, f'optim{index}') + scaler = getattr(self, f'scaler{index}') + scheduler = getattr(self, f'scheduler{index}') + warmup_scheduler = getattr(self, f'warmup{index}') + + # set the grad scaler on the accelerator, since we are managing one per u-net + + if exists(self.max_grad_norm): + self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) + + optimizer.step() + optimizer.zero_grad() + + if self.use_ema: + ema_unet = self.get_ema_unet(unet_number) + ema_unet.update() + + # scheduler, if needed + + maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() + + with maybe_warmup_context: + if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs + scheduler.step() + + self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) + + if not exists(self.checkpoint_path): + return + + total_steps = int(self.steps.sum().item()) + + if total_steps % self.checkpoint_every: + return + + self.save_to_checkpoint_folder() + + @torch.no_grad() + @cast_torch_tensor + @imagen_sample_in_chunks + def sample(self, *args, **kwargs): + context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets + + self.print_untrained_unets() + + if not self.is_main: + kwargs['use_tqdm'] = False + + with context(): + output = self.imagen.sample(*args, device = self.device, **kwargs) + + return output + + @partial(cast_torch_tensor, cast_fp16 = True) + def forward( + self, + *args, + unet_number = None, + max_batch_size = None, + **kwargs + ): + unet_number = self.validate_unet_number(unet_number) + self.validate_and_set_unet_being_trained(unet_number) + self.set_accelerator_scaler(unet_number) + + assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' + + total_loss = 0. 
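+        # note on the loop below: split_args_and_kwargs cuts the incoming batch into chunks of
+        # at most max_batch_size; each chunk's loss is scaled by its fraction of the full batch
+        # and backpropagated per chunk, so the accumulated gradients match one full-batch pass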
+ + + # + for debug + if self.CKeys['Debug_TrainerPack']==1: + print("In Trainer:Forward, check inputs:") + print('args: ', len(args)) + print('args in:',args[0].shape) + print('kwargs: ', kwargs.keys()) + for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): + # + for debug + if self.CKeys['Debug_TrainerPack']==1: + print("after chunks,...") + print('chun_frac: ', chunk_size_frac) + print('chun_args: ', chunked_args) + print('chun_kwargs: ', chunked_kwargs) + + with self.accelerator.autocast(): + loss = self.model( + *chunked_args, + unet_number = unet_number, + **chunked_kwargs + ) + loss = loss * chunk_size_frac + + # + for debug + if self.CKeys['Debug_TrainerPack']==1: + print('part chun loss: ', loss) + + total_loss += loss#.item() + + if self.training: + self.accelerator.backward(loss) + + return total_loss + +# ======================================================== +# +class ImagenTrainer_ModelB(nn.Module): + locked = False + + def __init__( + self, + #imagen = None, + model = None, + + imagen_checkpoint_path = None, + use_ema = True, + lr = 1e-4, + eps = 1e-8, + beta1 = 0.9, + beta2 = 0.99, + max_grad_norm = None, + group_wd_params = True, + warmup_steps = None, + cosine_decay_max_steps = None, + only_train_unet_number = None, + fp16 = False, + precision = None, + split_batches = True, + dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), + verbose = True, + split_valid_fraction = 0.025, + split_valid_from_train = False, + split_random_seed = 42, + checkpoint_path = None, + checkpoint_every = None, + checkpoint_fs = None, + fs_kwargs: dict = None, + max_checkpoints_keep = 20, + # +++++++++++++++++++++ + CKeys=None, + # + **kwargs + ): + super().__init__() + assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' + assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' + + # determine filesystem, using fsspec, for saving to local filesystem or cloud + + self.fs = checkpoint_fs + + if not exists(self.fs): + fs_kwargs = default(fs_kwargs, {}) + self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) + + # # ----------------------------------- + # # from MJB + # assert isinstance(model.imagen, (ProteinDesigner_B)) + # modified by BN + # ++ + assert isinstance(model, (ProteinDesigner_B)) + + # +++++++++++++++++++++++++ + self.CKeys = CKeys + + ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) + + + self.imagen = model.imagen + + + + self.model=model + self.is_elucidated = self.model.is_elucidated + # create accelerator instance + + accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) + + assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' + accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') + + self.accelerator = Accelerator(**{ + 'split_batches': split_batches, + 'mixed_precision': accelerator_mixed_precision, + 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] + , **accelerate_kwargs}) + + ImagenTrainer.locked = self.is_distributed + + # cast data to fp16 at training time if needed + + 
self.cast_half_at_training = accelerator_mixed_precision == 'fp16' + + # grad scaler must be managed outside of accelerator + + grad_scaler_enabled = fp16 + + self.num_unets = len(self.imagen.unets) + + self.use_ema = use_ema and self.is_main + self.ema_unets = nn.ModuleList([]) + + # keep track of what unet is being trained on + # only going to allow 1 unet training at a time + + self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on + + # data related functions + + self.train_dl_iter = None + self.train_dl = None + + self.valid_dl_iter = None + self.valid_dl = None + + self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names + + # auto splitting validation from training, if dataset is passed in + + self.split_valid_from_train = split_valid_from_train + + assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' + self.split_valid_fraction = split_valid_fraction + self.split_random_seed = split_random_seed + + # be able to finely customize learning rate, weight decay + # per unet + + lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) + + for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): + optimizer = Adam( + unet.parameters(), + lr = unet_lr, + eps = unet_eps, + betas = (beta1, beta2), + **kwargs + ) + + if self.use_ema: + self.ema_unets.append(EMA(unet, **ema_kwargs)) + + scaler = GradScaler(enabled = grad_scaler_enabled) + + scheduler = warmup_scheduler = None + + if exists(unet_cosine_decay_max_steps): + scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) + + if exists(unet_warmup_steps): + warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) + + if not exists(scheduler): + scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) + + # set on object + + setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers + setattr(self, f'scaler{ind}', scaler) + setattr(self, f'scheduler{ind}', scheduler) + setattr(self, f'warmup{ind}', warmup_scheduler) + + # gradient clipping if needed + + self.max_grad_norm = max_grad_norm + + # step tracker and misc + + self.register_buffer('steps', torch.tensor([0] * self.num_unets)) + + self.verbose = verbose + + # automatic set devices based on what accelerator decided + + self.imagen.to(self.device) + self.to(self.device) + + # checkpointing + + assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) + self.checkpoint_path = checkpoint_path + self.checkpoint_every = checkpoint_every + self.max_checkpoints_keep = max_checkpoints_keep + + self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main + + if exists(checkpoint_path) and self.can_checkpoint: + bucket = url_to_bucket(checkpoint_path) + + if not self.fs.exists(bucket): + self.fs.mkdir(bucket) + + self.load_from_checkpoint_folder() + + # only allowing training for unet + + self.only_train_unet_number = only_train_unet_number + self.validate_and_set_unet_being_trained(only_train_unet_number) + + # computed values + + @property + def device(self): + return self.accelerator.device + + @property + def is_distributed(self): + return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) + + @property + def 
is_main(self): + return self.accelerator.is_main_process + + @property + def is_local_main(self): + return self.accelerator.is_local_main_process + + @property + def unwrapped_unet(self): + return self.accelerator.unwrap_model(self.unet_being_trained) + + # optimizer helper functions + + def get_lr(self, unet_number): + self.validate_unet_number(unet_number) + unet_index = unet_number - 1 + + optim = getattr(self, f'optim{unet_index}') + + return optim.param_groups[0]['lr'] + + # function for allowing only one unet from being trained at a time + + def validate_and_set_unet_being_trained(self, unet_number = None): + if exists(unet_number): + self.validate_unet_number(unet_number) + + assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' + + self.only_train_unet_number = unet_number + self.imagen.only_train_unet_number = unet_number + + if not exists(unet_number): + return + + self.wrap_unet(unet_number) + + def wrap_unet(self, unet_number): + if hasattr(self, 'one_unet_wrapped'): + return + + unet = self.imagen.get_unet(unet_number) + self.unet_being_trained = self.accelerator.prepare(unet) + unet_index = unet_number - 1 + + optimizer = getattr(self, f'optim{unet_index}') + scheduler = getattr(self, f'scheduler{unet_index}') + + optimizer = self.accelerator.prepare(optimizer) + + if exists(scheduler): + scheduler = self.accelerator.prepare(scheduler) + + setattr(self, f'optim{unet_index}', optimizer) + setattr(self, f'scheduler{unet_index}', scheduler) + + self.one_unet_wrapped = True + + # hacking accelerator due to not having separate gradscaler per optimizer + + def set_accelerator_scaler(self, unet_number): + unet_number = self.validate_unet_number(unet_number) + scaler = getattr(self, f'scaler{unet_number - 1}') + + self.accelerator.scaler = scaler + for optimizer in self.accelerator._optimizers: + optimizer.scaler = scaler + + # helper print + + def print(self, msg): + if not self.is_main: + return + + if not self.verbose: + return + + return self.accelerator.print(msg) + + # validating the unet number + + def validate_unet_number(self, unet_number = None): + if self.num_unets == 1: + unet_number = default(unet_number, 1) + + assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' + return unet_number + + # number of training steps taken + + def num_steps_taken(self, unet_number = None): + if self.num_unets == 1: + unet_number = default(unet_number, 1) + + return self.steps[unet_number - 1].item() + + def print_untrained_unets(self): + print_final_error = False + + for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): + if steps > 0 or isinstance(unet, NullUnet): + continue + + self.print(f'unet {ind + 1} has not been trained') + print_final_error = True + + if print_final_error: + self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') + + # data related functions + + def add_train_dataloader(self, dl = None): + if not exists(dl): + return + + assert not exists(self.train_dl), 'training dataloader was already added' + self.train_dl = self.accelerator.prepare(dl) + + def add_valid_dataloader(self, dl): + if not exists(dl): + return + + assert not exists(self.valid_dl), 'validation dataloader was already added' + self.valid_dl = self.accelerator.prepare(dl) 
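+    # usage sketch (illustrative only; 'designer', 'train_ds' and the step count are
+    # hypothetical names, not defined in this module):
+    #
+    #   trainer = ImagenTrainer_ModelB(model = designer, lr = 1e-4, fp16 = True,
+    #                                  CKeys = {'Debug_TrainerPack': 0})
+    #   trainer.add_train_dataset(train_ds, batch_size = 16)
+    #   for _ in range(100000):
+    #       loss = trainer.train_step(unet_number = 1, max_batch_size = 4)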
+
+    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
+        if not exists(ds):
+            return
+
+        assert not exists(self.train_dl), 'training dataloader was already added'
+
+        valid_ds = None
+        if self.split_valid_from_train:
+            train_size = int((1 - self.split_valid_fraction) * len(ds))
+            valid_size = len(ds) - train_size
+
+            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
+            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
+
+        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
+        self.train_dl = self.accelerator.prepare(dl)
+
+        if not self.split_valid_from_train:
+            return
+
+        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)
+
+    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
+        if not exists(ds):
+            return
+
+        assert not exists(self.valid_dl), 'validation dataloader was already added'
+
+        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
+        self.valid_dl = self.accelerator.prepare(dl)
+
+    def create_train_iter(self):
+        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'
+
+        if exists(self.train_dl_iter):
+            return
+
+        self.train_dl_iter = cycle(self.train_dl)
+
+    def create_valid_iter(self):
+        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'
+
+        if exists(self.valid_dl_iter):
+            return
+
+        self.valid_dl_iter = cycle(self.valid_dl)
+
+    def train_step(self, unet_number = None, **kwargs):
+        self.create_train_iter()
+        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
+        self.update(unet_number = unet_number)
+        return loss
+
+    @torch.no_grad()
+    @eval_decorator
+    def valid_step(self, **kwargs):
+        self.create_valid_iter()
+
+        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
+
+        with context():
+            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
+        return loss
+
+    def step_with_dl_iter(self, dl_iter, **kwargs):
+        dl_tuple_output = cast_tuple(next(dl_iter))
+        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
+        loss = self.forward(**{**kwargs, **model_input})
+        return loss
+
+    # checkpointing functions
+
+    @property
+    def all_checkpoints_sorted(self):
+        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
+        checkpoints = self.fs.glob(glob_pattern)
+        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
+        return sorted_checkpoints
+
+    def load_from_checkpoint_folder(self, last_total_steps = -1):
+        if last_total_steps != -1:
+            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
+            self.load(filepath)
+            return
+
+        sorted_checkpoints = self.all_checkpoints_sorted
+
+        if len(sorted_checkpoints) == 0:
+            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
+            return
+
+        last_checkpoint = sorted_checkpoints[0]
+        self.load(last_checkpoint)
+
+    def save_to_checkpoint_folder(self):
+        self.accelerator.wait_for_everyone()
+
+        if not self.can_checkpoint:
+            return
+
+        total_steps = int(self.steps.sum().item())
+        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')
+
+        self.save(filepath)
+
+        if self.max_checkpoints_keep <= 0:
+            return
+
+        sorted_checkpoints = self.all_checkpoints_sorted
+        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]
+
+        for checkpoint in
checkpoints_to_discard: + self.fs.rm(checkpoint) + + # saving and loading functions + + def save( + self, + path, + overwrite = True, + without_optim_and_sched = False, + **kwargs + ): + self.accelerator.wait_for_everyone() + + if not self.can_checkpoint: + return + + fs = self.fs + + assert not (fs.exists(path) and not overwrite) + + self.reset_ema_unets_all_one_device() + + save_obj = dict( + model = self.imagen.state_dict(), + version = __version__, + steps = self.steps.cpu(), + **kwargs + ) + + save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() + + for ind in save_optim_and_sched_iter: + scaler_key = f'scaler{ind}' + optimizer_key = f'optim{ind}' + scheduler_key = f'scheduler{ind}' + warmup_scheduler_key = f'warmup{ind}' + + scaler = getattr(self, scaler_key) + optimizer = getattr(self, optimizer_key) + scheduler = getattr(self, scheduler_key) + warmup_scheduler = getattr(self, warmup_scheduler_key) + + if exists(scheduler): + save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} + + if exists(warmup_scheduler): + save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} + + save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} + + if self.use_ema: + save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} + + # determine if imagen config is available + + if hasattr(self.imagen, '_config'): + self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""') + + save_obj = { + **save_obj, + 'imagen_type': 'elucidated' if self.is_elucidated else 'original', + 'imagen_params': self.imagen._config + } + + #save to path + + with fs.open(path, 'wb') as f: + torch.save(save_obj, f) + + self.print(f'checkpoint saved to {path}') + + def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): + fs = self.fs + + if noop_if_not_exist and not fs.exists(path): + self.print(f'trainer checkpoint not found at {str(path)}') + return + + assert fs.exists(path), f'{path} does not exist' + + self.reset_ema_unets_all_one_device() + + # to avoid extra GPU memory usage in main process when using Accelerate + + with fs.open(path) as f: + loaded_obj = torch.load(f, map_location='cpu') + + if version.parse(__version__) != version.parse(loaded_obj['version']): + self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') + + try: + self.imagen.load_state_dict(loaded_obj['model'], strict = strict) + except RuntimeError: + print("Failed loading state dict. 
Trying partial load") + self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), + loaded_obj['model'])) + + if only_model: + return loaded_obj + + self.steps.copy_(loaded_obj['steps']) + + for ind in range(0, self.num_unets): + scaler_key = f'scaler{ind}' + optimizer_key = f'optim{ind}' + scheduler_key = f'scheduler{ind}' + warmup_scheduler_key = f'warmup{ind}' + + scaler = getattr(self, scaler_key) + optimizer = getattr(self, optimizer_key) + scheduler = getattr(self, scheduler_key) + warmup_scheduler = getattr(self, warmup_scheduler_key) + + if exists(scheduler) and scheduler_key in loaded_obj: + scheduler.load_state_dict(loaded_obj[scheduler_key]) + + if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: + warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) + + if exists(optimizer): + try: + optimizer.load_state_dict(loaded_obj[optimizer_key]) + scaler.load_state_dict(loaded_obj[scaler_key]) + except: + self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers') + + if self.use_ema: + assert 'ema' in loaded_obj + try: + self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) + except RuntimeError: + print("Failed loading state dict. Trying partial load") + self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), + loaded_obj['ema'])) + + self.print(f'checkpoint loaded from {path}') + return loaded_obj + + # managing ema unets and their devices + + @property + def unets(self): + return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) + + def get_ema_unet(self, unet_number = None): + if not self.use_ema: + return + + unet_number = self.validate_unet_number(unet_number) + index = unet_number - 1 + + if isinstance(self.unets, nn.ModuleList): + unets_list = [unet for unet in self.ema_unets] + delattr(self, 'ema_unets') + self.ema_unets = unets_list + + if index != self.ema_unet_being_trained_index: + for unet_index, unet in enumerate(self.ema_unets): + unet.to(self.device if unet_index == index else 'cpu') + + self.ema_unet_being_trained_index = index + return self.ema_unets[index] + + def reset_ema_unets_all_one_device(self, device = None): + if not self.use_ema: + return + + device = default(device, self.device) + self.ema_unets = nn.ModuleList([*self.ema_unets]) + self.ema_unets.to(device) + + self.ema_unet_being_trained_index = -1 + + @torch.no_grad() + @contextmanager + def use_ema_unets(self): + if not self.use_ema: + output = yield + return output + + self.reset_ema_unets_all_one_device() + self.imagen.reset_unets_all_one_device() + + self.unets.eval() + + trainable_unets = self.imagen.unets + self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling + + output = yield + + self.imagen.unets = trainable_unets # restore original training unets + + # cast the ema_model unets back to original device + for ema in self.ema_unets: + ema.restore_ema_model_device() + + return output + + def print_unet_devices(self): + self.print('unet devices:') + for i, unet in enumerate(self.imagen.unets): + device = next(unet.parameters()).device + self.print(f'\tunet {i}: {device}') + + if not self.use_ema: + return + + self.print('\nema unet devices:') + for i, ema_unet in enumerate(self.ema_unets): + device = next(ema_unet.parameters()).device + self.print(f'\tema unet {i}: {device}') + + # overriding state dict functions + + def state_dict(self, *args, **kwargs): + 
self.reset_ema_unets_all_one_device() + return super().state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + self.reset_ema_unets_all_one_device() + return super().load_state_dict(*args, **kwargs) + + # encoding text functions + + def encode_text(self, text, **kwargs): + return self.imagen.encode_text(text, **kwargs) + + # forwarding functions and gradient step updates + + def update(self, unet_number = None): + unet_number = self.validate_unet_number(unet_number) + self.validate_and_set_unet_being_trained(unet_number) + self.set_accelerator_scaler(unet_number) + + index = unet_number - 1 + unet = self.unet_being_trained + + optimizer = getattr(self, f'optim{index}') + scaler = getattr(self, f'scaler{index}') + scheduler = getattr(self, f'scheduler{index}') + warmup_scheduler = getattr(self, f'warmup{index}') + + # set the grad scaler on the accelerator, since we are managing one per u-net + + if exists(self.max_grad_norm): + self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) + + optimizer.step() + optimizer.zero_grad() + + if self.use_ema: + ema_unet = self.get_ema_unet(unet_number) + ema_unet.update() + + # scheduler, if needed + + maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() + + with maybe_warmup_context: + if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs + scheduler.step() + + self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) + + if not exists(self.checkpoint_path): + return + + total_steps = int(self.steps.sum().item()) + + if total_steps % self.checkpoint_every: + return + + self.save_to_checkpoint_folder() + + @torch.no_grad() + @cast_torch_tensor + @imagen_sample_in_chunks + def sample(self, *args, **kwargs): + context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets + + self.print_untrained_unets() + + if not self.is_main: + kwargs['use_tqdm'] = False + + with context(): + output = self.imagen.sample(*args, device = self.device, **kwargs) + + return output + + @partial(cast_torch_tensor, cast_fp16 = True) + def forward( + self, + *args, + unet_number = None, + max_batch_size = None, + **kwargs + ): + unet_number = self.validate_unet_number(unet_number) + self.validate_and_set_unet_being_trained(unet_number) + self.set_accelerator_scaler(unet_number) + + assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' + + total_loss = 0. 
+ + + # + for debug + if self.CKeys['Debug_TrainerPack']==1: + print("In Trainer:Forward, check inputs:") + print('args: ', len(args)) + print('args in:',args[0].shape) + print('kwargs: ', kwargs.keys()) + for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): + # + for debug + if self.CKeys['Debug_TrainerPack']==1: + print("after chunks,...") + print('chun_frac: ', chunk_size_frac) + print('chun_args: ', chunked_args) + print('chun_kwargs: ', chunked_kwargs) + + with self.accelerator.autocast(): + loss = self.model( + *chunked_args, + unet_number = unet_number, + **chunked_kwargs + ) + loss = loss * chunk_size_frac + + # + for debug + if self.CKeys['Debug_TrainerPack']==1: + print('part chun loss: ', loss) + + total_loss += loss#.item() + + if self.training: + self.accelerator.backward(loss) + + return total_loss + +class ImagenTrainer_Old(nn.Module): + locked = False + + def __init__( + self, + #imagen = None, + model = None, + + imagen_checkpoint_path = None, + use_ema = True, + lr = 1e-4, + eps = 1e-8, + beta1 = 0.9, + beta2 = 0.99, + max_grad_norm = None, + group_wd_params = True, + warmup_steps = None, + cosine_decay_max_steps = None, + only_train_unet_number = None, + fp16 = False, + precision = None, + split_batches = True, + dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), + verbose = True, + split_valid_fraction = 0.025, + split_valid_from_train = False, + split_random_seed = 42, + checkpoint_path = None, + checkpoint_every = None, + checkpoint_fs = None, + fs_kwargs: dict = None, + max_checkpoints_keep = 20, + **kwargs + ): + super().__init__() + assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' + assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' + + # determine filesystem, using fsspec, for saving to local filesystem or cloud + + self.fs = checkpoint_fs + + if not exists(self.fs): + fs_kwargs = default(fs_kwargs, {}) + self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) + + # # ----------------------------------- + # # from MJB + # assert isinstance(model.imagen, (ProteinDesigner_B)) + + ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) + + + self.imagen = model.imagen + + + + self.model=model + self.is_elucidated = self.model.is_elucidated + # create accelerator instance + + accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) + + assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' + accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') + + self.accelerator = Accelerator(**{ + 'split_batches': split_batches, + 'mixed_precision': accelerator_mixed_precision, + 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] + , **accelerate_kwargs}) + + ImagenTrainer.locked = self.is_distributed + + # cast data to fp16 at training time if needed + + self.cast_half_at_training = accelerator_mixed_precision == 'fp16' + + # grad scaler must be managed outside of accelerator + + grad_scaler_enabled = fp16 + + self.num_unets = len(self.imagen.unets) + + self.use_ema = use_ema and self.is_main + self.ema_unets 
= nn.ModuleList([]) + + # keep track of what unet is being trained on + # only going to allow 1 unet training at a time + + self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on + + # data related functions + + self.train_dl_iter = None + self.train_dl = None + + self.valid_dl_iter = None + self.valid_dl = None + + self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names + + # auto splitting validation from training, if dataset is passed in + + self.split_valid_from_train = split_valid_from_train + + assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' + self.split_valid_fraction = split_valid_fraction + self.split_random_seed = split_random_seed + + # be able to finely customize learning rate, weight decay + # per unet + + lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) + + for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): + optimizer = Adam( + unet.parameters(), + lr = unet_lr, + eps = unet_eps, + betas = (beta1, beta2), + **kwargs + ) + + if self.use_ema: + self.ema_unets.append(EMA(unet, **ema_kwargs)) + + scaler = GradScaler(enabled = grad_scaler_enabled) + + scheduler = warmup_scheduler = None + + if exists(unet_cosine_decay_max_steps): + scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) + + if exists(unet_warmup_steps): + warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) + + if not exists(scheduler): + scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) + + # set on object + + setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers + setattr(self, f'scaler{ind}', scaler) + setattr(self, f'scheduler{ind}', scheduler) + setattr(self, f'warmup{ind}', warmup_scheduler) + + # gradient clipping if needed + + self.max_grad_norm = max_grad_norm + + # step tracker and misc + + self.register_buffer('steps', torch.tensor([0] * self.num_unets)) + + self.verbose = verbose + + # automatic set devices based on what accelerator decided + + self.imagen.to(self.device) + self.to(self.device) + + # checkpointing + + assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) + self.checkpoint_path = checkpoint_path + self.checkpoint_every = checkpoint_every + self.max_checkpoints_keep = max_checkpoints_keep + + self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main + + if exists(checkpoint_path) and self.can_checkpoint: + bucket = url_to_bucket(checkpoint_path) + + if not self.fs.exists(bucket): + self.fs.mkdir(bucket) + + self.load_from_checkpoint_folder() + + # only allowing training for unet + + self.only_train_unet_number = only_train_unet_number + self.validate_and_set_unet_being_trained(only_train_unet_number) + + # computed values + + @property + def device(self): + return self.accelerator.device + + @property + def is_distributed(self): + return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) + + @property + def is_main(self): + return self.accelerator.is_main_process + + @property + def is_local_main(self): + return self.accelerator.is_local_main_process + + @property + def unwrapped_unet(self): + return self.accelerator.unwrap_model(self.unet_being_trained) + + # optimizer 
helper functions + + def get_lr(self, unet_number): + self.validate_unet_number(unet_number) + unet_index = unet_number - 1 + + optim = getattr(self, f'optim{unet_index}') + + return optim.param_groups[0]['lr'] + + # function for allowing only one unet from being trained at a time + + def validate_and_set_unet_being_trained(self, unet_number = None): + if exists(unet_number): + self.validate_unet_number(unet_number) + + assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' + + self.only_train_unet_number = unet_number + self.imagen.only_train_unet_number = unet_number + + if not exists(unet_number): + return + + self.wrap_unet(unet_number) + + def wrap_unet(self, unet_number): + if hasattr(self, 'one_unet_wrapped'): + return + + unet = self.imagen.get_unet(unet_number) + self.unet_being_trained = self.accelerator.prepare(unet) + unet_index = unet_number - 1 + + optimizer = getattr(self, f'optim{unet_index}') + scheduler = getattr(self, f'scheduler{unet_index}') + + optimizer = self.accelerator.prepare(optimizer) + + if exists(scheduler): + scheduler = self.accelerator.prepare(scheduler) + + setattr(self, f'optim{unet_index}', optimizer) + setattr(self, f'scheduler{unet_index}', scheduler) + + self.one_unet_wrapped = True + + # hacking accelerator due to not having separate gradscaler per optimizer + + def set_accelerator_scaler(self, unet_number): + unet_number = self.validate_unet_number(unet_number) + scaler = getattr(self, f'scaler{unet_number - 1}') + + self.accelerator.scaler = scaler + for optimizer in self.accelerator._optimizers: + optimizer.scaler = scaler + + # helper print + + def print(self, msg): + if not self.is_main: + return + + if not self.verbose: + return + + return self.accelerator.print(msg) + + # validating the unet number + + def validate_unet_number(self, unet_number = None): + if self.num_unets == 1: + unet_number = default(unet_number, 1) + + assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' + return unet_number + + # number of training steps taken + + def num_steps_taken(self, unet_number = None): + if self.num_unets == 1: + unet_number = default(unet_number, 1) + + return self.steps[unet_number - 1].item() + + def print_untrained_unets(self): + print_final_error = False + + for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): + if steps > 0 or isinstance(unet, NullUnet): + continue + + self.print(f'unet {ind + 1} has not been trained') + print_final_error = True + + if print_final_error: + self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') + + # data related functions + + def add_train_dataloader(self, dl = None): + if not exists(dl): + return + + assert not exists(self.train_dl), 'training dataloader was already added' + self.train_dl = self.accelerator.prepare(dl) + + def add_valid_dataloader(self, dl): + if not exists(dl): + return + + assert not exists(self.valid_dl), 'validation dataloader was already added' + self.valid_dl = self.accelerator.prepare(dl) + + def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs): + if not exists(ds): + return + + assert not exists(self.train_dl), 'training dataloader was already added' + + valid_ds = None + if self.split_valid_from_train: + train_size = int((1 - 
self.split_valid_fraction) * len(ds)) + valid_size = len(ds) - train_size + + ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed)) + self.print(f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples') + + dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) + self.train_dl = self.accelerator.prepare(dl) + + if not self.split_valid_from_train: + return + + self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs) + + def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs): + if not exists(ds): + return + + assert not exists(self.valid_dl), 'validation dataloader was already added' + + dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) + self.valid_dl = self.accelerator.prepare(dl) + + def create_train_iter(self): + assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet' + + if exists(self.train_dl_iter): + return + + self.train_dl_iter = cycle(self.train_dl) + + def create_valid_iter(self): + assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet' + + if exists(self.valid_dl_iter): + return + + self.valid_dl_iter = cycle(self.valid_dl) + + def train_step(self, unet_number = None, **kwargs): + self.create_train_iter() + loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) + self.update(unet_number = unet_number) + return loss + + @torch.no_grad() + @eval_decorator + def valid_step(self, **kwargs): + self.create_valid_iter() + + context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext + + with context(): + loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs) + return loss + + def step_with_dl_iter(self, dl_iter, **kwargs): + dl_tuple_output = cast_tuple(next(dl_iter)) + model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output))) + loss = self.forward(**{**kwargs, **model_input}) + return loss + + # checkpointing functions + + @property + def all_checkpoints_sorted(self): + glob_pattern = os.path.join(self.checkpoint_path, '*.pt') + checkpoints = self.fs.glob(glob_pattern) + sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True) + return sorted_checkpoints + + def load_from_checkpoint_folder(self, last_total_steps = -1): + if last_total_steps != -1: + filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt') + self.load(filepath) + return + + sorted_checkpoints = self.all_checkpoints_sorted + + if len(sorted_checkpoints) == 0: + self.print(f'no checkpoints found to load from at {self.checkpoint_path}') + return + + last_checkpoint = sorted_checkpoints[0] + self.load(last_checkpoint) + + def save_to_checkpoint_folder(self): + self.accelerator.wait_for_everyone() + + if not self.can_checkpoint: + return + + total_steps = int(self.steps.sum().item()) + filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt') + + self.save(filepath) + + if self.max_checkpoints_keep <= 0: + return + + sorted_checkpoints = self.all_checkpoints_sorted + checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] + + for checkpoint in checkpoints_to_discard: + self.fs.rm(checkpoint) + + # saving and loading functions + + def save( + self, + path, + overwrite = True, + without_optim_and_sched = False, + **kwargs + ): + self.accelerator.wait_for_everyone() + + if not self.can_checkpoint: + 
return + + fs = self.fs + + assert not (fs.exists(path) and not overwrite) + + self.reset_ema_unets_all_one_device() + + save_obj = dict( + model = self.imagen.state_dict(), + version = __version__, + steps = self.steps.cpu(), + **kwargs + ) + + save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() + + for ind in save_optim_and_sched_iter: + scaler_key = f'scaler{ind}' + optimizer_key = f'optim{ind}' + scheduler_key = f'scheduler{ind}' + warmup_scheduler_key = f'warmup{ind}' + + scaler = getattr(self, scaler_key) + optimizer = getattr(self, optimizer_key) + scheduler = getattr(self, scheduler_key) + warmup_scheduler = getattr(self, warmup_scheduler_key) + + if exists(scheduler): + save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} + + if exists(warmup_scheduler): + save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} + + save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} + + if self.use_ema: + save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} + + # determine if imagen config is available + + if hasattr(self.imagen, '_config'): + self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""') + + save_obj = { + **save_obj, + 'imagen_type': 'elucidated' if self.is_elucidated else 'original', + 'imagen_params': self.imagen._config + } + + #save to path + + with fs.open(path, 'wb') as f: + torch.save(save_obj, f) + + self.print(f'checkpoint saved to {path}') + + def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): + fs = self.fs + + if noop_if_not_exist and not fs.exists(path): + self.print(f'trainer checkpoint not found at {str(path)}') + return + + assert fs.exists(path), f'{path} does not exist' + + self.reset_ema_unets_all_one_device() + + # to avoid extra GPU memory usage in main process when using Accelerate + + with fs.open(path) as f: + loaded_obj = torch.load(f, map_location='cpu') + + if version.parse(__version__) != version.parse(loaded_obj['version']): + self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') + + try: + self.imagen.load_state_dict(loaded_obj['model'], strict = strict) + except RuntimeError: + print("Failed loading state dict. Trying partial load") + self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), + loaded_obj['model'])) + + if only_model: + return loaded_obj + + self.steps.copy_(loaded_obj['steps']) + + for ind in range(0, self.num_unets): + scaler_key = f'scaler{ind}' + optimizer_key = f'optim{ind}' + scheduler_key = f'scheduler{ind}' + warmup_scheduler_key = f'warmup{ind}' + + scaler = getattr(self, scaler_key) + optimizer = getattr(self, optimizer_key) + scheduler = getattr(self, scheduler_key) + warmup_scheduler = getattr(self, warmup_scheduler_key) + + if exists(scheduler) and scheduler_key in loaded_obj: + scheduler.load_state_dict(loaded_obj[scheduler_key]) + + if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: + warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) + + if exists(optimizer): + try: + optimizer.load_state_dict(loaded_obj[optimizer_key]) + scaler.load_state_dict(loaded_obj[scaler_key]) + except: + self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. 
resuming with new optimizer and scalers') + + if self.use_ema: + assert 'ema' in loaded_obj + try: + self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) + except RuntimeError: + print("Failed loading state dict. Trying partial load") + self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), + loaded_obj['ema'])) + + self.print(f'checkpoint loaded from {path}') + return loaded_obj + + # managing ema unets and their devices + + @property + def unets(self): + return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) + + def get_ema_unet(self, unet_number = None): + if not self.use_ema: + return + + unet_number = self.validate_unet_number(unet_number) + index = unet_number - 1 + + if isinstance(self.unets, nn.ModuleList): + unets_list = [unet for unet in self.ema_unets] + delattr(self, 'ema_unets') + self.ema_unets = unets_list + + if index != self.ema_unet_being_trained_index: + for unet_index, unet in enumerate(self.ema_unets): + unet.to(self.device if unet_index == index else 'cpu') + + self.ema_unet_being_trained_index = index + return self.ema_unets[index] + + def reset_ema_unets_all_one_device(self, device = None): + if not self.use_ema: + return + + device = default(device, self.device) + self.ema_unets = nn.ModuleList([*self.ema_unets]) + self.ema_unets.to(device) + + self.ema_unet_being_trained_index = -1 + + @torch.no_grad() + @contextmanager + def use_ema_unets(self): + if not self.use_ema: + output = yield + return output + + self.reset_ema_unets_all_one_device() + self.imagen.reset_unets_all_one_device() + + self.unets.eval() + + trainable_unets = self.imagen.unets + self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling + + output = yield + + self.imagen.unets = trainable_unets # restore original training unets + + # cast the ema_model unets back to original device + for ema in self.ema_unets: + ema.restore_ema_model_device() + + return output + + def print_unet_devices(self): + self.print('unet devices:') + for i, unet in enumerate(self.imagen.unets): + device = next(unet.parameters()).device + self.print(f'\tunet {i}: {device}') + + if not self.use_ema: + return + + self.print('\nema unet devices:') + for i, ema_unet in enumerate(self.ema_unets): + device = next(ema_unet.parameters()).device + self.print(f'\tema unet {i}: {device}') + + # overriding state dict functions + + def state_dict(self, *args, **kwargs): + self.reset_ema_unets_all_one_device() + return super().state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + self.reset_ema_unets_all_one_device() + return super().load_state_dict(*args, **kwargs) + + # encoding text functions + + def encode_text(self, text, **kwargs): + return self.imagen.encode_text(text, **kwargs) + + # forwarding functions and gradient step updates + + def update(self, unet_number = None): + unet_number = self.validate_unet_number(unet_number) + self.validate_and_set_unet_being_trained(unet_number) + self.set_accelerator_scaler(unet_number) + + index = unet_number - 1 + unet = self.unet_being_trained + + optimizer = getattr(self, f'optim{index}') + scaler = getattr(self, f'scaler{index}') + scheduler = getattr(self, f'scheduler{index}') + warmup_scheduler = getattr(self, f'warmup{index}') + + # set the grad scaler on the accelerator, since we are managing one per u-net + + if exists(self.max_grad_norm): + self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) + + optimizer.step() + optimizer.zero_grad() + + if self.use_ema: + ema_unet = 
self.get_ema_unet(unet_number) + ema_unet.update() + + # scheduler, if needed + + maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() + + with maybe_warmup_context: + if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs + scheduler.step() + + self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) + + if not exists(self.checkpoint_path): + return + + total_steps = int(self.steps.sum().item()) + + if total_steps % self.checkpoint_every: + return + + self.save_to_checkpoint_folder() + + @torch.no_grad() + @cast_torch_tensor + @imagen_sample_in_chunks + def sample(self, *args, **kwargs): + context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets + + self.print_untrained_unets() + + if not self.is_main: + kwargs['use_tqdm'] = False + + with context(): + output = self.imagen.sample(*args, device = self.device, **kwargs) + + return output + + @partial(cast_torch_tensor, cast_fp16 = True) + def forward( + self, + *args, + unet_number = None, + max_batch_size = None, + **kwargs + ): + unet_number = self.validate_unet_number(unet_number) + self.validate_and_set_unet_being_trained(unet_number) + self.set_accelerator_scaler(unet_number) + + assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' + + total_loss = 0. + + + # + for debug + print('args: ', len(args)) + print('args in:',args[0].shape) + print('kwargs: ', kwargs.keys()) + for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): + # + for debug + print('chun_frac: ', chunk_size_frac) + print('chun_args: ', chunked_args) + print('chun_kwargs: ', chunked_kwargs) + + with self.accelerator.autocast(): + loss = self.model(*chunked_args, unet_number = unet_number, **chunked_kwargs) + loss = loss * chunk_size_frac + + print('loss: ', loss) + + total_loss += loss#.item() + + if self.training: + self.accelerator.backward(loss) + + print('I am here') + return total_loss + + +def write_fasta (sequence, filename_out): + + with open (filename_out, mode ='w') as f: + f.write (f'>{filename_out}\n') + f.write (f'{sequence}') + + +# +def sample_sequence_FromModelB ( + model, + X=None, #this is the target conventionally when using text embd + flag=0, + cond_scales=1., + foldproteins=False, + X_string=None, + x_data=None, + skip_steps=0, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, + # ++++++++++++++++++++++++ + ynormfac=1, + train_unet_number=1, + tokenizer_X=None, + Xnormfac=1., + max_length=1., + prefix=None, + tokenizer_y=None, + ): + steps=0 + e=flag + + + + + #num_samples = min (num_samples,y_train_batch.shape[0] ) + if X!=None: + print (f"Producing {len(X)} samples...from text conditioning X...") + lenn_val=len(X) + if X_string!=None: + lenn_val=len(X_string) + print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") + + if x_data!=None: + print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") + lenn_val=len(x_data) + print (x_data) + + print ('Device: ', device) + + + for iisample in range (lenn_val): + X_cond=None + if X_string==None and X != None: #only do if X provided + X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) + if X_string !=None: + X = 
tokenizer_X.texts_to_sequences(X_string[iisample]) + X= sequence.pad_sequences(X, maxlen=max_length, padding='post', truncating='post') + X=np.array(X) + X_cond=torch.from_numpy(X).float()/Xnormfac + print ('Tokenized and processed: ', X_cond) + + print ("X_cond=", X_cond) + + result=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + x_data=x_data, skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + ) + result=torch.round(result*ynormfac) + + plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted') + #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}') + plt.legend() + + outname = prefix+ f"sampled_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" + #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}") + plt.savefig(outname, dpi=200) + plt.show () + + to_rev=result[:,0,:] + to_rev=to_rev.long().cpu().detach().numpy() + print (to_rev.shape) + y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") + + ### reverse second structure input.... + if X_cond != None: + X_cond=torch.round(X_cond*Xnormfac) + + to_rev=X_cond[:,:] + to_rev=to_rev.long().cpu().detach().numpy() + print (to_rev.shape) + X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") + if x_data !=None: + X_data_reversed=x_data #is already in sequence fromat.. + + + print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample]) + if foldproteins: + + if X_cond != None: + xbc=X_cond[iisample,:].cpu().detach().numpy() + out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}' + if x_data !=None: + #xbc=x_data[iisample] + out_nam=x_data[iisample] + + + tempname='temp' + pdb_file=foldandsavePDB ( + sequence=y_data_reversed[0], + filename_out=tempname, + num_cycle=num_cycle, + flag=flag, + # +++++++++++++++++++ + prefix=prefix + ) + + out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta' + + write_fasta (y_data_reversed[0], out_nam_fasta) + + + out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' + # print('Debug 1: out: ', out_nam) + # print('Debug 2: in: ', pdb_file) + shutil.copy (pdb_file, out_nam) #source, dest + # cmd_line = 'cp ' + pdb_file + ' ' + out_nam + # print(cmd_line) + # os.popen(cmd_line) + # print('Debug 3') + pdb_file=out_nam + + + + + + print (f"Properly named PDB file produced: {pdb_file}") + #flag=1000 + view=show_pdb(pdb_file=pdb_file, flag=flag, + show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color) + view.show() + + + steps=steps+1 + + return pdb_file + +def sample_loop_FromModelB (model, + train_loader, + cond_scales=[7.5], #list of cond scales - each sampled... + num_samples=2, #how many samples produced every time tested..... 
+ timesteps=100, + flag=0,foldproteins=False, + use_text_embedd=True,skip_steps=0, + # +++++++++++++++++++ + train_unet_number=1, + ynormfac=1, + prefix=None, + tokenizer_y=None, + Xnormfac=1, + tokenizer_X=None, + + ): + steps=0 + e=flag + for item in train_loader: + + X_train_batch= item[0].to(device) + y_train_batch=item[1].to(device) + + GT=y_train_batch.cpu().detach() + + GT= GT.unsqueeze(1) + num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + print ('X_train_batch shape: ', X_train_batch.shape) + + for iisample in range (len (cond_scales)): + + if use_text_embedd: + result=model.sample (x= X_train_batch,stop_at_unet_number=train_unet_number , + cond_scale=cond_scales[iisample], device=device, skip_steps=skip_steps) + else: + result=model.sample (x= None, x_data_tokenized= X_train_batch, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales[iisample],device=device,skip_steps=skip_steps) + + result=torch.round(result*ynormfac) + GT=torch.round (GT*ynormfac) + + for samples in range (num_samples): + print ("sample ", samples, "out of ", num_samples) + + plt.plot (result[samples,0,:].cpu().detach().numpy(),label= f'Predicted') + plt.plot (GT[samples,0,:],label= f'GT {0}') + plt.legend() + + outname = prefix+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" + + plt.savefig(outname, dpi=200) + plt.show () + + #reverse y sequence + to_rev=result[:,0,:] + to_rev=to_rev.long().cpu().detach().numpy() + + y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") + + #reverse GT_y sequence + to_rev=GT[:,0,:] + to_rev=to_rev.long().cpu().detach().numpy() + + GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") + + ### reverse second structure input.... 
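+                    # Round trip back to text: the conditioning tokens were divided
+                    # by Xnormfac when the dataset was built, so we rescale, round
+                    # to integer token ids, and let tokenizer_X.sequences_to_texts
+                    # decode them. E.g. with a hypothetical Xnormfac=8, a stored
+                    # value of 0.375 maps back to token id 3.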
+ to_rev=torch.round (X_train_batch[:,:]*Xnormfac) + to_rev=to_rev.long().cpu().detach().numpy() + + X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") + + print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples]) + print (f"Ground truth: {GT_y_data_reversed[samples]}") + + if foldproteins: + xbc=X_train_batch[samples,:].cpu().detach().numpy() + out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) + tempname='temp' + pdb_file=foldandsavePDB ( + sequence=y_data_reversed[samples], + filename_out=tempname, + num_cycle=16, flag=flag, + # +++++++++++++++++++ + prefix=prefix + ) + + #out_nam=f'{prefix}{out_nam}.pdb' + out_nam=f'{prefix}{X_data_reversed[samples]}.pdb' + print (f'Original PDB: {pdb_file} OUT: {out_nam}') + shutil.copy (pdb_file, out_nam) #source, dest + pdb_file=out_nam + print (f"Properly named PDB file produced: {pdb_file}") + + view=show_pdb(pdb_file=pdb_file, flag=flag, show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color) + view.show() + + steps=steps+1 + if steps>num_samples: + break +# ++ +def sample_sequence_omegafold_pLM_ModelB ( + model, + X=None, #this is the target conventionally when using text embd + flag=0, + cond_scales=1., + foldproteins=False, + X_string=None, + x_data=None, + skip_steps=0, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, + # ++++++++++++++++++++++++ + ynormfac=1, + train_unet_number=1, + tokenizer_X=None, + Xnormfac=1., + max_length=1., + prefix=None, + tokenizer_y=None, + # ++ + CKeys=None, + sample_dir=None, + steps=None, + e=None, + IF_showfig=True, # effective only after foldproteins=True + # ++ + pLM_Model=None, # pLM_Model, + pLM_Model_Name=None, # pLM_Model_Name, + image_channels=None, # image_channels, + pLM_alphabet=None, # esm_alphabet, +): + + # steps=0 + # e=flag + + #num_samples = min (num_samples,y_train_batch.shape[0] ) + if X!=None: + print (f"Producing {len(X)} samples...from text conditioning X...") + lenn_val=len(X) + if X_string!=None: + lenn_val=len(X_string) + print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") + + if x_data!=None: + print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") + lenn_val=len(x_data) + print (x_data) + + print ('Device: ', device) + + pdb_file_list=[] + fasta_file_list=[] + + # + for debug + print('tot ', lenn_val) + for iisample in range (lenn_val): + print("Working on ", iisample) + X_cond=None # this is for text-conditioning + if X_string==None and X != None: #only do if X provided + X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) + if X_string !=None: + XX = tokenizer_X.texts_to_sequences(X_string[iisample]) + XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') + XX=np.array(XX) + X_cond=torch.from_numpy(XX).float()/Xnormfac + print ('Tokenized and processed: ', X_cond) + + print ("X_cond=", X_cond) + + # # -- + # result=model.sample ( + # x=X_cond, + # stop_at_unet_number=train_unet_number , + # cond_scale=cond_scales , + # x_data=x_data[iisample], + # # ++ + # x_data_tokenized= + # skip_steps=skip_steps, + # inpaint_images = inpaint_images, + # inpaint_masks = inpaint_masks, + # inpaint_resample_times = inpaint_resample_times, + # init_images = 
init_images,device=device, + # # ++++++++++++++++++++++++++ + # tokenizer_X=tokenizer_X, + # Xnormfac=Xnormfac, + # max_length=max_length, + # ) + # ++ + # use cond_image as the conditioning, via x_data_tokenized channel + + # ----------------------------------------------------------------- + # for below, two branches are all for cond_img, not for text_cond + if tokenizer_X!=None: + # for SecStr+ModelB + result_embedding=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + x_data=x_data[iisample], # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim + # ++ + x_data_tokenized=None, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + ) + else: + # for ForcPath+ModelB: + # for model.sample() here using x_data_tokenized channel + # + x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac) + x_data_tokenized=x_data_tokenized.to(torch.float) + # here, only one input list is read in + x_data_tokenized=x_data_tokenized.unsqueeze(0) # [batch=1, seq_len] + # leave channel expansion for the self.sample() to handle + + # + for debug: + if CKeys['Debug_TrainerPack']==3: + print("x_data_tokenized dim: ", x_data_tokenized.shape) + print("x_data_tokenized dtype: ", x_data_tokenized.dtype) + print("test x_data_tokenized!=None: ", x_data_tokenized!=None) + + result_embedding=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + x_data=None, + # ++ + x_data_tokenized=x_data_tokenized, + # + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + ) + + # # -- + # result=torch.round(result*ynormfac) # (batch=1, channel=1, seq_len) + + # ++ for pLM + # full record + # result_embedding as image.dim: [batch, channels, seq_len] + # result_tokens.dim: [batch, seq_len] + if image_channels==33: + result_tokens,result_logits = convert_into_tokens_using_prob( + result_embedding, + pLM_Model_Name, + ) + else: + result_tokens,result_logits = convert_into_tokens( + pLM_Model, + result_embedding, + pLM_Model_Name, + ) + # +++++++++++++++++++++++++++++++++ + result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len] + + # + for debug + print('result dim: ', result.shape) + + # plot sequence token code: esm (33 tokens) + fig=plt.figure() + plt.plot ( + result[0,0,:].cpu().detach().numpy(), + label= f'Predicted' + ) + #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}') + plt.legend() + outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" + #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}") + if IF_showfig==1: + plt.show () + else: + plt.savefig(outname, dpi=200) + plt.close() + + # + # # -- + # to_rev=result[:,0,:] + # to_rev=to_rev.long().cpu().detach().numpy() + # print (to_rev.shape) + # y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + # + # for iii in range (len(y_data_reversed)): + # y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") + + # # ++: for model A: no mask is provided from input + # # reverse the PREDICTED y 
into a foldable sequence + # # save this block for Model A + # to_rev=result[:,0,:] # token (batch,seq_len) + # y_data_reversed=decode_many_ems_token_rec_for_folding( + # to_rev, + # result_logits, + # pLM_alphabet, + # pLM_Model, + # ) + + # if CKeys['Debug_TrainerPack']==3: + # print("on foldable result: ", to_rev[0]) + # print("on result_logits: ", result_logits[0]) + # a = decode_one_ems_token_rec_for_folding( + # to_rev[0], + # result_logits[0], + # pLM_alphabet, + # pLM_Model, + # ) + # print('One resu: ', a) + # print("on y_data_reversed: ", y_data_reversed[0]) + # print("y_data_reversed.type", y_data_reversed.dtype) + # + + # ++: for model B: using mask from the input + # extract the mask/seq_len from input if possible + if tokenizer_X!=None: + # for SecStr+ModelB + result_mask = read_mask_from_input( + tokenized_data=None, + mask_value=None, + seq_data=x_data[iisample], + max_seq_length=max_length, + ) + else: + # for ForcPath+ModelB + result_mask = read_mask_from_input( + tokenized_data=x_data_tokenized, # None, + mask_value=0, # None, + seq_data=None, # x_data[iisample], + max_seq_length=None, # max_length, + ) + + to_rev=result[:,0,:] # token (batch,seq_len) + if CKeys['Debug_TrainerPack']==3: + print("on foldable result: ", to_rev[0]) + print("on result_logits: ", result_logits[0]) + print("on mask: ", result_mask[0]) + a = decode_one_ems_token_rec_for_folding_with_mask( + to_rev[0], + result_logits[0], + pLM_alphabet, + pLM_Model, + result_mask[0], + ) + print('One resu: ', a) + + y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( + to_rev, + result_logits, + pLM_alphabet, + pLM_Model, + result_mask, + ) + if CKeys['Debug_TrainerPack']==3: + print("on y_data_reversed[0]: ", y_data_reversed[0]) + + + + ### reverse second structure input.... + if X_cond != None: + X_cond=torch.round(X_cond*Xnormfac) + + to_rev=X_cond[:,:] + to_rev=to_rev.long().cpu().detach().numpy() + print (to_rev.shape) + X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") + if x_data !=None: + X_data_reversed=x_data #is already in sequence fromat.. 
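+        # Note on the mask-aware decoding above: result_mask flags the valid
+        # (non-padding) positions recovered from the conditioning input, so
+        # decoding stops at the true sequence length instead of reading padded
+        # tokens. A hypothetical example (token ids are illustrative, not the
+        # real esm alphabet mapping):
+        #   result tokens: [ 5, 12,  9,  4,  1,  1]
+        #   result_mask  : [ 1,  1,  1,  1,  0,  0]  -> only 4 residues decoded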
+ + + # print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence", y_data_reversed[iisample]) + print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed) + + # + for debug + print("================================================") + print("foldproteins: ", foldproteins) + + if not foldproteins: + pdb_file=None + + else: + + if X_cond != None: + xbc=X_cond[iisample,:].cpu().detach().numpy() + out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}' + if x_data !=None: + #xbc=x_data[iisample] + # ---------------------------------- + # this one can be too long for a name + out_nam=x_data[iisample] + # ++++++++++++++++++++++++++++++++++ + # + out_nam=iisample + + + tempname='temp' + pdb_file, fasta_file=foldandsavePDB_pdb_fasta ( + sequence=y_data_reversed[0], + filename_out=tempname, + num_cycle=num_cycle, + flag=flag, + # +++++++++++++++++++ + # prefix=prefix, + prefix=sample_dir, + ) + + # out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta' + # ------------------------------------------ + # this one can be too long for a name + # out_nam_fasta=f'{sample_dir}{out_nam}_{flag}_{e}_{iisample}.fasta' + # ++++++++++++++++++++++++++++++++++++++++++ + out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' + + write_fasta (y_data_reversed[0], out_nam_fasta) + + + # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' + # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' + # ------------------------------------------- + # this one can be too long for a name + # However, the input X is recorded in the code + # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb' + # +++++++++++++++++++++++++++++++++++++++++++ + out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb' + out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' + + # print('Debug 1: out: ', out_nam) + # print('Debug 2: in: ', pdb_file) + shutil.copy (pdb_file, out_nam) #source, dest + shutil.copy (fasta_file, out_nam_fasta) + # cmd_line = 'cp ' + pdb_file + ' ' + out_nam + # print(cmd_line) + # os.popen(cmd_line) + # print('Debug 3') + # clean the slade to avoid mistakenly using the previous fasta file + os.remove (pdb_file) + os.remove (fasta_file) + + pdb_file=out_nam + fasta_file=out_nam_fasta + pdb_file_list.append(pdb_file) + fasta_file_list.append(fasta_file) + + + print (f"Properly named PDB file produced: {pdb_file}") + if IF_showfig==1: + #flag=1000 + view=show_pdb( + pdb_file=pdb_file, + flag=flag, + show_sidechains=show_sidechains, + show_mainchains=show_mainchains, + color=color + ) + view.show() + + + # steps=steps+1 + + return pdb_file_list, fasta_file_list + +# +def sample_sequence_omegafold_ModelB ( + model, + X=None, #this is the target conventionally when using text embd + flag=0, + cond_scales=1., + foldproteins=False, + X_string=None, + x_data=None, + skip_steps=0, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, + # ++++++++++++++++++++++++ + ynormfac=1, + train_unet_number=1, + tokenizer_X=None, + Xnormfac=1., + max_length=1., + prefix=None, + tokenizer_y=None, + # ++ + CKeys=None, + sample_dir=None, + steps=None, + e=None, + IF_showfig=True, # effective only after foldproteins=True +): + + # steps=0 + # e=flag + + #num_samples = min (num_samples,y_train_batch.shape[0] ) + if X!=None: + print (f"Producing {len(X)} samples...from text 
conditioning X...") + lenn_val=len(X) + if X_string!=None: + lenn_val=len(X_string) + print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") + + if x_data!=None: + print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") + lenn_val=len(x_data) + print (x_data) + + print ('Device: ', device) + + # + for debug + print('tot ', lenn_val) + for iisample in range (lenn_val): + print("Working on ", iisample) + X_cond=None + if X_string==None and X != None: #only do if X provided + X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) + if X_string !=None: + XX = tokenizer_X.texts_to_sequences(X_string[iisample]) + XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') + XX=np.array(XX) + X_cond=torch.from_numpy(XX).float()/Xnormfac + print ('Tokenized and processed: ', X_cond) + + print ("X_cond=", X_cond) + + # # -- + # result=model.sample ( + # x=X_cond, + # stop_at_unet_number=train_unet_number , + # cond_scale=cond_scales , + # x_data=x_data[iisample], + # # ++ + # x_data_tokenized= + # skip_steps=skip_steps, + # inpaint_images = inpaint_images, + # inpaint_masks = inpaint_masks, + # inpaint_resample_times = inpaint_resample_times, + # init_images = init_images,device=device, + # # ++++++++++++++++++++++++++ + # tokenizer_X=tokenizer_X, + # Xnormfac=Xnormfac, + # max_length=max_length, + # ) + # ++ + if tokenizer_X!=None: + result=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + x_data=x_data[iisample], + # ++ + x_data_tokenized=None, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + ) + else: + x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac) + x_data_tokenized=x_data_tokenized.to(torch.float) + # + for debug: + if CKeys['Debug_TrainerPack']==1: + print("x_data_tokenized dim: ", x_data_tokenized.shape) + print("x_data_tokenized dtype: ", x_data_tokenized.dtype) + print("test: ", x_data_tokenized!=None) + result=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + x_data=None, + # ++ + x_data_tokenized=x_data_tokenized, + # + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + ) + + + result=torch.round(result*ynormfac) + # + for debug + print('result dim: ', result.shape) + + fig=plt.figure() + plt.plot ( + result[0,0,:].cpu().detach().numpy(), + label= f'Predicted' + ) + #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}') + plt.legend() + outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" + #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}") + if IF_showfig==1: + plt.show () + else: + plt.savefig(outname, dpi=200) + plt.close() + + + to_rev=result[:,0,:] + to_rev=to_rev.long().cpu().detach().numpy() + print (to_rev.shape) + y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") + + ### reverse second structure input.... 
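+        # Two cases for recovering a printable conditioning record: if the
+        # condition came in through the text channel (X_cond), it is rescaled
+        # by Xnormfac, rounded, and de-tokenized below; if it came in as raw
+        # x_data, it is already a sequence string and is reused as-is.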
+ if X_cond != None: + X_cond=torch.round(X_cond*Xnormfac) + + to_rev=X_cond[:,:] + to_rev=to_rev.long().cpu().detach().numpy() + print (to_rev.shape) + X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") + if x_data !=None: + X_data_reversed=x_data #is already in sequence fromat.. + + + # print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence", y_data_reversed[iisample]) + print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed) + + # + for debug + print("================================================") + print("foldproteins: ", foldproteins) + + if not foldproteins: + pdb_file=None + + else: + + if X_cond != None: + xbc=X_cond[iisample,:].cpu().detach().numpy() + out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}' + if x_data !=None: + #xbc=x_data[iisample] + # ---------------------------------- + # this one can be too long for a name + out_nam=x_data[iisample] + # ++++++++++++++++++++++++++++++++++ + # + out_nam=iisample + + + tempname='temp' + pdb_file=foldandsavePDB ( + sequence=y_data_reversed[0], + filename_out=tempname, + num_cycle=num_cycle, + flag=flag, + # +++++++++++++++++++ + # prefix=prefix, + prefix=sample_dir, + ) + + # out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta' + # ------------------------------------------ + # this one can be too long for a name + # out_nam_fasta=f'{sample_dir}{out_nam}_{flag}_{e}_{iisample}.fasta' + # ++++++++++++++++++++++++++++++++++++++++++ + out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' + + write_fasta (y_data_reversed[0], out_nam_fasta) + + + # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' + # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' + # ------------------------------------------- + # this one can be too long for a name + # However, the input X is recorded in the code + # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb' + # +++++++++++++++++++++++++++++++++++++++++++ + out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb' + + # print('Debug 1: out: ', out_nam) + # print('Debug 2: in: ', pdb_file) + shutil.copy (pdb_file, out_nam) #source, dest + # cmd_line = 'cp ' + pdb_file + ' ' + out_nam + # print(cmd_line) + # os.popen(cmd_line) + # print('Debug 3') + pdb_file=out_nam + + + + print (f"Properly named PDB file produced: {pdb_file}") + if IF_showfig==1: + #flag=1000 + view=show_pdb( + pdb_file=pdb_file, + flag=flag, + show_sidechains=show_sidechains, + show_mainchains=show_mainchains, + color=color + ) + view.show() + + + # steps=steps+1 + + return pdb_file + +# ++ for de novo input of ForcPath +# ++ +def sample_sequence_omegafold_pLM_ModelB_For_ForcPath ( + model, + X=None, #this is the target conventionally when using text embd + flag=0, + cond_scales=[1.], # 1., + foldproteins=False, + X_string=None, + x_data=None, + skip_steps=0, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, + # ++++++++++++++++++++++++ + ynormfac=1, + train_unet_number=1, + tokenizer_X=None, + Xnormfac=1., + max_length=1., + prefix=None, + tokenizer_y=None, + # ++ + CKeys=None, + sample_dir=None, + steps=None, + e=None, + IF_showfig=True, # effective only after foldproteins=True + # ++ + pLM_Model=None, # pLM_Model, + pLM_Model_Name=None, # 
pLM_Model_Name, + image_channels=None, # image_channels, + pLM_alphabet=None, # esm_alphabet, +): + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # prepare input in different channels + # + X_cond=None # this is for text-conditioning + if X_string==None and X != None: #only do if X provided + print (f"Producing {len(X)} samples...from text conditioning X...") + lenn_val=len(X) + # shape of X: [[..],[..]]: double bracket + X_cond=torch.Tensor(X).to(device) + # -- + # X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) + if X_string !=None: + print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") + lenn_val=len(X_string) + # -- + XX = tokenizer_X.texts_to_sequences(X_string[iisample]) + # ++ + XX = tokenizer_X.texts_to_sequences(X_string) + XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') + XX=np.array(XX) + X_cond=torch.from_numpy(XX).float()/Xnormfac + print ('Tokenized and processed: ', X_cond) + + if x_data!=None: + print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") + lenn_val=len(x_data) + if tokenizer_X==None: # for ForcPath, + # need to do Padding and Normalization + # and then put into tokenized data channel + x_data_tokenized=[] + for ii in range(lenn_val): + x_data_one_line=pad_a_np_arr(x_data[ii], 0.0, max_length) + x_data_tokenized.append(x_data_one_line) + x_data_tokenized=np.array(x_data_tokenized) + x_data_tokenized=torch.from_numpy(x_data_tokenized/Xnormfac) + else: + # leave for SecStr case: TBA + pass + # print (x_data) + # ++ for result_mask based on input: x_data or x_data_tokenized + # ++: for model B: using mask from the input + # extract the mask/seq_len from input if possible + if tokenizer_X!=None: + # for SecStr+ModelB + result_mask = read_mask_from_input( + tokenized_data=None, + mask_value=None, + seq_data=x_data, # x_data[iisample], + max_seq_length=max_length, + ) + else: + # for ForcPath+ModelB + result_mask = read_mask_from_input( + tokenized_data=x_data_tokenized, # None, + mask_value=0, # None, + seq_data=None, # x_data[iisample], + max_seq_length=None, # max_length, + ) + + + print ("Input contents:") + print ("cond_img condition: x_data=\n", x_data) + print ("Text condition: X_cond=\n", X_cond) + + # store the results + pdb_file_list=[] + fasta_file_list=[] + + # loop over cond_scales + for idx_cond, this_cond_scale in enumerate(cond_scales): + print(f"Working on cond_scale {str(this_cond_scale)}") + # do sampling + # ----------------------------------------------------------------- + # for below, two branches are all for cond_img, not for text_cond + if tokenizer_X!=None: + # for SecStr+ModelB, not test here + result_embedding=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=this_cond_scale, # cond_scales , + x_data=x_data, # x_data[iisample], # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim + # ++ + x_data_tokenized=None, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + ) + else: + # for ForcPath+ModelB: + # for model.sample() here using x_data_tokenized channel + x_data_tokenized=x_data_tokenized.to(torch.float) # shape [batch, max_seq_len] + # leave channel expansion for the self.sample() to 
handle + + # + for debug: + if CKeys['Debug_TrainerPack']==3: + print("x_data_tokenized dim: ", x_data_tokenized.shape) + print("x_data_tokenized dtype: ", x_data_tokenized.dtype) + print("test x_data_tokenized!=None: ", x_data_tokenized!=None) + + result_embedding=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=this_cond_scale, # cond_scales , + x_data=None, + # ++ + x_data_tokenized=x_data_tokenized, + # + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + ) + + # handle the results: from embedding into AA + # ++ for pLM + if image_channels==33: + # pass + result_tokens,result_logits = convert_into_tokens_using_prob( + result_embedding, + pLM_Model_Name, + ) + else: + # full record + # result_embedding as image.dim: [batch, channels, seq_len] + # result_tokens.dim: [batch, seq_len] + result_tokens,result_logits = convert_into_tokens( + pLM_Model, + result_embedding, + pLM_Model_Name, + ) + # +++++++++++++++++++++++++++++++++ + result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len] + + # + for debug + print('result dim: ', result.shape) + + # plot sequence token code: esm (33 tokens), for one batch + fig=plt.figure() + for ii in range(lenn_val): + plt.plot ( + result[ii,0,:].cpu().detach().numpy(), + label= f'Predicted for Input#{str(ii)}' + ) + #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}') + plt.legend() + outname = sample_dir+ f"DenovoInputXs_CondScale_No{str(idx_cond)}_Val_{str(this_cond_scale)}_{e}_{steps}.jpg" + #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}") + if IF_showfig==1: + plt.show () + else: + plt.savefig(outname, dpi=200) + plt.close() + + # translate result into AA + to_rev=result[:,0,:] # token (batch,seq_len) + if CKeys['Debug_TrainerPack']==3: + print("on foldable result: ", to_rev[0]) + print("on result_logits: ", result_logits[0]) + print("on mask: ", result_mask[0]) + a = decode_one_ems_token_rec_for_folding_with_mask( + to_rev[0], + result_logits[0], + pLM_alphabet, + pLM_Model, + result_mask[0], + ) + print('One resu: ', a) + + y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( + to_rev, + result_logits, + pLM_alphabet, + pLM_Model, + result_mask, + ) + if CKeys['Debug_TrainerPack']==3: + print("on y_data_reversed[0]: ", y_data_reversed[0]) + + ### reverse second structure input.... + if X_cond != None: + X_cond=torch.round(X_cond*Xnormfac) + + to_rev=X_cond[:,:] + to_rev=to_rev.long().cpu().detach().numpy() + print (to_rev.shape) + X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") + if x_data !=None: + # work for second structure input.... + # work for ForcPath input... + X_data_reversed=x_data #is already in sequence fromat.. 
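+        # Per-sample outputs are written in the loop below: a FASTA file is
+        # always saved; when foldproteins is on, the folded PDB plus a .txt
+        # record of the input condition are saved alongside it, all keyed by
+        # sample index, cond-scale index/value, epoch and step.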
+ + # sections for each one result + for iisample in range(lenn_val): + print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample]) + + out_nam_fasta=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.fasta' + write_fasta (y_data_reversed[iisample], out_nam_fasta) + fasta_file_list.append(out_nam_fasta) + + # + for debug + print("================================================") + print("foldproteins: ", foldproteins) + + if not foldproteins: + pdb_file=None + + else: + + if X_cond != None: + # not maintained + xbc=X_cond[iisample,:].cpu().detach().numpy() + out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})+f'_{flag}_{steps}' + if x_data !=None: + pass + # #xbc=x_data[iisample] + # # ---------------------------------- + # # this one can be too long for a name + # out_nam=x_data[iisample] + # # ++++++++++++++++++++++++++++++++++ + # # + # out_nam=iisample + + tempname='temp' + pdb_file, fasta_file=foldandsavePDB_pdb_fasta ( + sequence=y_data_reversed[iisample], + filename_out=tempname, + num_cycle=num_cycle, + flag=flag, + # +++++++++++++++++++ + # prefix=prefix, + prefix=sample_dir, + ) + + + # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' + # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' + # ------------------------------------------- + # this one can be too long for a name + # However, the input X is recorded in the code + # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb' + # +++++++++++++++++++++++++++++++++++++++++++ + out_nam=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.pdb' + # out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' + + # print('Debug 1: out: ', out_nam) + # print('Debug 2: in: ', pdb_file) + shutil.copy (pdb_file, out_nam) #source, dest + # shutil.copy (fasta_file, out_nam_fasta) + # cmd_line = 'cp ' + pdb_file + ' ' + out_nam + # print(cmd_line) + # os.popen(cmd_line) + # print('Debug 3') + # clean the slade to avoid mistakenly using the previous fasta file + os.remove (pdb_file) + os.remove (fasta_file) + + pdb_file=out_nam + # fasta_file=out_nam_fasta + pdb_file_list.append(pdb_file) + + # ++ write the input condtion as a reference: for ForcPath + out_nam_inX=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}_input.txt' + if torch.is_tensor(X_data_reversed[iisample]): + # for safety, not used usually + xbc=X_data_reversed[iisample].cpu().detach().numpy() + else: + xbc=X_data_reversed[iisample] + if tokenizer_X==None: + # for ForcPath case + out_inX=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc}) + else: + # for SecStr case + out_inX=xbc + with open(out_nam_inX, "w") as inX_file: + inX_file.write(out_inX) + + + print (f"Properly named PDB file produced: {pdb_file}") + if IF_showfig==1: + #flag=1000 + view=show_pdb( + pdb_file=pdb_file, + flag=flag, + show_sidechains=show_sidechains, + show_mainchains=show_mainchains, + color=color + ) + view.show() + + + return pdb_file_list, fasta_file_list + +# ++ +def sample_sequence_pLM_ModelB_For_ForcPath_Predictor ( + model, + X=None, #this is the target conventionally when using text embd + flag=0, + cond_scales=[1.], # 1., + foldproteins=False, + X_string=None, + x_data=None, + skip_steps=0, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, + # 
++++++++++++++++++++++++ + ynormfac=1, + train_unet_number=1, + tokenizer_X=None, + Xnormfac=1., + max_length=1., + prefix=None, + tokenizer_y=None, + # ++ + CKeys=None, + sample_dir=None, + steps=None, + e=None, + IF_showfig=True, # effective only after foldproteins=True + # ++ + pLM_Model=None, # pLM_Model, + pLM_Model_Name=None, # pLM_Model_Name, + image_channels=None, # image_channels, + pLM_alphabet=None, # esm_alphabet, + # ++ + esm_layer=None, +): + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # input: a list of AA sequence in string format + # output: ForcPath prediction + # + # 1. decide input channel: text_cond or img_cond + X_cond=None # this is for text-conditioning + if X_string==None and X != None: #only do if X provided + print (f"Producing {len(X)} samples...from text conditioning X...") + lenn_val=len(X) + # shape of X: [[..],[..]]: double bracket + X_cond=torch.Tensor(X).to(device) + # -- + # X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) + if X_string !=None: + print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") + lenn_val=len(X_string) + # -- + XX = tokenizer_X.texts_to_sequences(X_string[iisample]) + # ++ + XX = tokenizer_X.texts_to_sequences(X_string) + XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') + XX=np.array(XX) + X_cond=torch.from_numpy(XX).float()/Xnormfac + print ('Tokenized and processed: ', X_cond) + + + if x_data!=None: # this is for img_conditioning channel + # Will use this channel + print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") + lenn_val=len(x_data) + seq_len_list=[] + for this_AA in x_data: + seq_len_list.append(len(this_AA)) + + + print ("Input contents:") + print ("cond_img condition: x_data=\n", x_data) + print ("Text condition: X_cond=\n", X_cond) + + # 2. perform sampling + # loop over cond_scales + resu_prediction={} + + for idx_cond, this_cond_scale in enumerate(cond_scales): + print(f"Working on cond_scale {str(this_cond_scale)}") + # leave the translation from seq to tokenized + # in the model.sample function using x_data channel + # Need to pass on the esm model part or tokenizer_X + # + result_embedding=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=this_cond_scale, # cond_scales , + x_data=x_data, # x_data[iisample], # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim + # ++ + x_data_tokenized=None, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images,device=device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + # ++ for esm + pLM_Model=pLM_Model, + pLM_alphabet=pLM_alphabet, + esm_layer=esm_layer, + pLM_Model_Name=pLM_Model_Name, + # image_channels=image_channels, + ) + + # convert into prediction + # 3. 
translate prediction into something meaningful
+        # consider channel average and masking
+        # average across channels
+        result_embedding=torch.mean(result_embedding, 1) # (batch, seq_len)
+        # read mask from input: X_train_batch_picked (batch, seq_len)
+        # result_mask looks like 0,1,1,...,1,0,0
+        # and the 0th component will be filled with zero
+        result_mask = read_mask_from_input(
+            tokenized_data=None, # X_train_batch[:num_samples],
+            mask_value=0.0,
+            seq_data=x_data, # None,
+            max_seq_length=max_length, # None,
+        )
+        # apply mask to result: keep true and zero all false
+        # note: as a side effect, this also zeroes the 0th component
+        result = result_embedding.cpu()*result_mask # (batch, seq_len)
+        # result = result.cpu()
+        y_data_reversed = result*ynormfac
+        # 4. translate the results into a list
+        prediction_list = []
+        for ii in range(len(x_data)):
+            prediction_list.append(
+                y_data_reversed[ii, :seq_len_list[ii]+1]
+            )
+        if CKeys['Debug_TrainerPack']==3:
+            print("check prediction dim:")
+            print("model output: ", y_data_reversed[0])
+            print("prediction output: ", prediction_list[0])
+        #
+        # store the results
+        resu_prediction[str(this_cond_scale)]=prediction_list
+
+    return resu_prediction,seq_len_list
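+
+# ++ usage sketch (illustration only): how the predictor above is expected to
+# be wired up from a driver script. Every argument passed here (the trained
+# model, the pLM pieces, CKeys, max_length, ynormfac) is assumed to be set up
+# by the caller; the two AA strings are hypothetical inputs.
+def example_ForcPath_predictor_call (
+    model, pLM_Model, pLM_alphabet, esm_layer, pLM_Model_Name,
+    CKeys, max_length, ynormfac,
+):
+    AA_list = ['GAGAGS', 'GPGGYGPGQQ'] # hypothetical sequences to score
+    resu_prediction, seq_len_list = sample_sequence_pLM_ModelB_For_ForcPath_Predictor (
+        model=model,
+        x_data=AA_list, # img-conditioning channel takes raw AA strings
+        cond_scales=[1.], # one pass at guidance scale 1
+        train_unet_number=1,
+        ynormfac=ynormfac,
+        max_length=max_length,
+        CKeys=CKeys,
+        pLM_Model=pLM_Model,
+        pLM_Model_Name=pLM_Model_Name,
+        pLM_alphabet=pLM_alphabet,
+        esm_layer=esm_layer,
+    )
+    # resu_prediction maps str(cond_scale) to a list of per-residue ForcPath
+    # tensors, each trimmed to its input's true length (+1 for the 0th slot)
+    return resu_prediction[str(1.)], seq_len_list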
fasta_file_list=[] + +# # loop over cond_scales +# for idx_cond, this_cond_scale in enumerate(cond_scales): +# print(f"Working on cond_scale {str(this_cond_scale)}") +# # do sampling +# # ----------------------------------------------------------------- +# # for below, two branches are all for cond_img, not for text_cond +# if tokenizer_X!=None: +# # for SecStr+ModelB, not test here +# result_embedding=model.sample ( +# x=X_cond, +# stop_at_unet_number=train_unet_number , +# cond_scale=this_cond_scale, # cond_scales , +# x_data=x_data, # x_data[iisample], # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim +# # ++ +# x_data_tokenized=None, +# skip_steps=skip_steps, +# inpaint_images = inpaint_images, +# inpaint_masks = inpaint_masks, +# inpaint_resample_times = inpaint_resample_times, +# init_images = init_images,device=device, +# # ++++++++++++++++++++++++++ +# tokenizer_X=tokenizer_X, +# Xnormfac=Xnormfac, +# max_length=max_length, +# ) +# else: +# # for ForcPath+ModelB: +# # for model.sample() here using x_data_tokenized channel +# x_data_tokenized=x_data_tokenized.to(torch.float) # shape [batch, max_seq_len] +# # leave channel expansion for the self.sample() to handle + +# # + for debug: +# if CKeys['Debug_TrainerPack']==3: +# print("x_data_tokenized dim: ", x_data_tokenized.shape) +# print("x_data_tokenized dtype: ", x_data_tokenized.dtype) +# print("test x_data_tokenized!=None: ", x_data_tokenized!=None) + +# result_embedding=model.sample ( +# x=X_cond, +# stop_at_unet_number=train_unet_number , +# cond_scale=this_cond_scale, # cond_scales , +# x_data=None, +# # ++ +# x_data_tokenized=x_data_tokenized, +# # +# skip_steps=skip_steps, +# inpaint_images = inpaint_images, +# inpaint_masks = inpaint_masks, +# inpaint_resample_times = inpaint_resample_times, +# init_images = init_images,device=device, +# # ++++++++++++++++++++++++++ +# tokenizer_X=tokenizer_X, +# Xnormfac=Xnormfac, +# max_length=max_length, +# ) + +# # handle the results: from embedding into AA +# # ++ for pLM +# # full record +# # result_embedding as image.dim: [batch, channels, seq_len] +# # result_tokens.dim: [batch, seq_len] +# result_tokens,result_logits = convert_into_tokens( +# pLM_Model, +# result_embedding, +# pLM_Model_Name, +# ) +# # +++++++++++++++++++++++++++++++++ +# result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len] + +# # + for debug +# print('result dim: ', result.shape) + +# # plot sequence token code: esm (33 tokens), for one batch +# fig=plt.figure() +# for ii in range(lenn_val): +# plt.plot ( +# result[ii,0,:].cpu().detach().numpy(), +# label= f'Predicted for Input#{str(ii)}' +# ) +# #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}') +# plt.legend() +# outname = sample_dir+ f"DenovoInputXs_CondScale_No{str(idx_cond)}_Val_{str(this_cond_scale)}_{e}_{steps}.jpg" +# #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}") +# if IF_showfig==1: +# plt.show () +# else: +# plt.savefig(outname, dpi=200) +# plt.close() + +# # translate result into AA +# to_rev=result[:,0,:] # token (batch,seq_len) +# if CKeys['Debug_TrainerPack']==3: +# print("on foldable result: ", to_rev[0]) +# print("on result_logits: ", result_logits[0]) +# print("on mask: ", result_mask[0]) +# a = decode_one_ems_token_rec_for_folding_with_mask( +# to_rev[0], +# result_logits[0], +# pLM_alphabet, +# pLM_Model, +# result_mask[0], +# ) +# print('One resu: ', a) + +# y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( +# to_rev, +# result_logits, +# 
pLM_alphabet, +# pLM_Model, +# result_mask, +# ) +# if CKeys['Debug_TrainerPack']==3: +# print("on y_data_reversed[0]: ", y_data_reversed[0]) + +# ### reverse second structure input.... +# if X_cond != None: +# X_cond=torch.round(X_cond*Xnormfac) + +# to_rev=X_cond[:,:] +# to_rev=to_rev.long().cpu().detach().numpy() +# print (to_rev.shape) +# X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) + +# for iii in range (len(y_data_reversed)): +# X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") +# if x_data !=None: +# # work for second structure input.... +# # work for ForcPath input... +# X_data_reversed=x_data #is already in sequence fromat.. + +# # sections for each one result +# for iisample in range(lenn_val): +# print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample]) + +# out_nam_fasta=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.fasta' +# write_fasta (y_data_reversed[iisample], out_nam_fasta) +# fasta_file_list.append(out_nam_fasta) + +# # + for debug +# print("================================================") +# print("foldproteins: ", foldproteins) + +# if not foldproteins: +# pdb_file=None + +# else: + +# if X_cond != None: +# # not maintained +# xbc=X_cond[iisample,:].cpu().detach().numpy() +# out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})+f'_{flag}_{steps}' +# if x_data !=None: +# pass +# # #xbc=x_data[iisample] +# # # ---------------------------------- +# # # this one can be too long for a name +# # out_nam=x_data[iisample] +# # # ++++++++++++++++++++++++++++++++++ +# # # +# # out_nam=iisample + +# tempname='temp' +# pdb_file, fasta_file=foldandsavePDB_pdb_fasta ( +# sequence=y_data_reversed[iisample], +# filename_out=tempname, +# num_cycle=num_cycle, +# flag=flag, +# # +++++++++++++++++++ +# # prefix=prefix, +# prefix=sample_dir, +# ) + + +# # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' +# # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' +# # ------------------------------------------- +# # this one can be too long for a name +# # However, the input X is recorded in the code +# # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb' +# # +++++++++++++++++++++++++++++++++++++++++++ +# out_nam=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.pdb' +# # out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' + +# # print('Debug 1: out: ', out_nam) +# # print('Debug 2: in: ', pdb_file) +# shutil.copy (pdb_file, out_nam) #source, dest +# # shutil.copy (fasta_file, out_nam_fasta) +# # cmd_line = 'cp ' + pdb_file + ' ' + out_nam +# # print(cmd_line) +# # os.popen(cmd_line) +# # print('Debug 3') +# # clean the slade to avoid mistakenly using the previous fasta file +# os.remove (pdb_file) +# os.remove (fasta_file) + +# pdb_file=out_nam +# # fasta_file=out_nam_fasta +# pdb_file_list.append(pdb_file) + +# # ++ write the input condtion as a reference: for ForcPath +# out_nam_inX=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}_input.txt' +# if torch.is_tensor(X_data_reversed[iisample]): +# # for safety, not used usually +# xbc=X_data_reversed[iisample].cpu().detach().numpy() +# else: +# xbc=X_data_reversed[iisample] +# if tokenizer_X==None: +# # for ForcPath case +# out_inX=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc}) +# else: +# # for SecStr 
case +# out_inX=xbc +# with open(out_nam_inX, "w") as inX_file: +# inX_file.write(out_inX) + + +# print (f"Properly named PDB file produced: {pdb_file}") +# if IF_showfig==1: +# #flag=1000 +# view=show_pdb( +# pdb_file=pdb_file, +# flag=flag, +# show_sidechains=show_sidechains, +# show_mainchains=show_mainchains, +# color=color +# ) +# view.show() + + +# return pdb_file_list, fasta_file_list + +# ++ +# for ProteinDesigner +def sample_loop_omegafold_pLM_ModelB ( + model, + train_loader, + cond_scales=[7.5], #list of cond scales - each sampled... + num_samples=2, #how many samples produced every time tested..... + timesteps=100, # not used + flag=0, + foldproteins=False, + use_text_embedd=True, + skip_steps=0, + # +++++++++++++++++++ + train_unet_number=1, + ynormfac=1, + prefix=None, + tokenizer_y=None, + Xnormfac=1, + tokenizer_X=None, + # ++ + CKeys=None, + sample_dir=None, + steps=None, + e=None, + IF_showfig=True, # effective only after foldproteins=True + # ++ + pLM_Model=None, + pLM_Model_Name=None, + image_channels=None, + pLM_alphabet=None, +): + # ===================================================== + # sample # = num_samples*(# of mini-batches) + # ===================================================== + # steps=0 + # e=flag + # for item in train_loader: + for idx, item in enumerate(train_loader): + + X_train_batch= item[0].to(device) + y_train_batch=item[1].to(device) + + # -- + # # ++ for pLM case: + # if pLM_Model_Name=='None': + # # just use the encoded sequence + # # y_train_batch_in = y_train_batch.unsqueeze(1) + # X_train_batch_in = X_train_batch.unsqueeze(1) + # # pass + # elif pLM_Model_Name=='esm2_t33_650M_UR50D': + # # with torch.no_grad(): + # # results = pLM_Model( + # # y_train_batch, + # # repr_layers=[33], + # # return_contacts=False, + # # ) + # # y_train_batch_in = results["representations"][33] + # # y_train_batch_in = rearrange( + # # y_train_batch_in, + # # 'b l c -> b c l' + # # ) + # X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) + + + GT=y_train_batch.cpu().detach() + + GT= GT.unsqueeze(1) + if num_samples>y_train_batch.shape[0]: + print("Warning: sampling # > len(mini_batch)") + + num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + X_train_batch_picked = X_train_batch[:num_samples,:] # X_train_batch_in[:num_samples ] # + print ('After pLM, (TEST) X_batch shape: ', X_train_batch_picked.shape) + + for iisample in range (len (cond_scales)): + + if use_text_embedd: + result_embedding=model.sample ( + # x= X_train_batch, + x= X_train_batch_picked, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales[iisample], + device=device, + skip_steps=skip_steps + ) + else: + result_embedding=model.sample ( + x= None, + # x_data_tokenized= X_train_batch, + x_data_tokenized= X_train_batch_picked, # dim=(batch, seq_len), will extend channels inside .sample(), + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales[iisample], + device=device, + skip_steps=skip_steps + ) + # ++ for pLM: + if image_channels==33: + result_tokens,result_logits = convert_into_tokens_using_prob( + result_embedding, + pLM_Model_Name, + ) + else: + # full record + # result_embedding as image.dim: [batch, channels, seq_len] + # result_tokens.dim: [batch, seq_len] + result_tokens,result_logits = convert_into_tokens( + pLM_Model, + result_embedding, + pLM_Model_Name, + ) + + # # --------------------------------- + # result=torch.round(result*ynormfac) + # GT=torch.round (GT*ynormfac) + # 
+++++++++++++++++++++++++++++++++ + result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len] + + # + + # # ------------------------------------------- + # #reverse y sequence + # to_rev=result[:,0,:] + # to_rev=to_rev.long().cpu().detach().numpy() + # + # y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + # + # for iii in range (len(y_data_reversed)): + # y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") + # ++++++++++++++++++++++++++++++++++++++++++++ + # extract the mask/seq_len from input if possible + # here, from dataloader, we only use tokenized_data for mask generation + result_mask = read_mask_from_input( + tokenized_data=X_train_batch[:num_samples], + mask_value=0, + seq_data=None, + max_seq_length=None, + ) + to_rev=result[:,0,:] # token (batch,seq_len) + if CKeys['Debug_TrainerPack']==3: + print("on foldable result: ", to_rev[0]) + print("on result_logits: ", result_logits[0]) + print("on mask: ", result_mask[0]) + a = decode_one_ems_token_rec_for_folding_with_mask( + to_rev[0], + result_logits[0], + pLM_alphabet, + pLM_Model, + result_mask[0], + ) + print('One resu: ', a) + + y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( + to_rev, + result_logits, + pLM_alphabet, + pLM_Model, + result_mask, + ) + if CKeys['Debug_TrainerPack']==3: + print("on y_data_reversed[0]: ", y_data_reversed[0]) + + # # ++++++++++++++++++++++++++++++++++++++++++++ + # # reverse the PREDICTED y into a foldable sequence + # # save this block for Model A + # to_rev=result[:,0,:] # token (batch,seq_len) + # y_data_reversed=decode_many_ems_token_rec_for_folding( + # to_rev, + # result_logits, + # pLM_alphabet, + # pLM_Model, + # ) + # if CKeys['Debug_TrainerPack']==3: + # print("on foldable result: ", to_rev[0]) + # print("on result_logits: ", result_logits[0]) + # a = decode_one_ems_token_rec_for_folding( + # to_rev[0], + # result_logits[0], + # pLM_alphabet, + # pLM_Model, + # ) + # print('One resu: ', a) + # print("on y_data_reversed: ", y_data_reversed[0]) + + + # # ----------------------------------------------------- + # #reverse GT_y sequence + # to_rev=GT[:,0,:] + # to_rev=to_rev.long().cpu().detach().numpy() + # + # GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + # + # for iii in range (len(y_data_reversed)): + # GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ + #reverse GT_y sequence + # GT should be SAFE to reverse + to_rev=GT[:,0,:] # (batch,1,seq_len)->(batch, seq_len) + GT_y_data_reversed=decode_many_ems_token_rec( + to_rev, + pLM_alphabet, + ) + + + # -- not for SecStr anymore + # ### reverse second structure input.... + # to_rev=torch.round (X_train_batch[:,:]*Xnormfac) + # to_rev=to_rev.long().cpu().detach().numpy() + # ++ + ### reverse general float input... + to_rev=X_train_batch[:,:]*Xnormfac + to_rev=to_rev.cpu().detach().numpy() + # here, assume X_train_batch is for cond_img: there are padding at both beginning and ending part + # so, first move the 0th padding to the end: + # Note: + # 1. this is good for SecStr case: (not maintained here) + # 2. 
this is not good for ForcPath, but can be cued in MD postprocess since the first component will always be 0 + n_batch=to_rev.shape[0] + n_embed=to_rev.shape[1] + to_rev_1 = np.zeros(to_rev.shape) + to_rev_1[:,0:n_embed-1]=to_rev[:,1:n_embed] + + # ++ different input + if tokenizer_X!=None: + # change into int + to_rev_1 = np.round(to_rev_1) + X_data_reversed=tokenizer_X.sequences_to_texts (to_rev_1) + for iii in range (len(y_data_reversed)): + X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") + else: + X_data_reversed=to_rev_1.copy() + # + for debug + if CKeys['Debug_TrainerPack']==1: + print("X_data_reversed: ", X_data_reversed) + + + for samples in range (num_samples): + print ("sample ", samples+1, "out of ", num_samples) + + fig=plt.figure() + plt.plot ( + result[samples,0,:].cpu().detach().numpy(), + label= f'Predicted' + ) + plt.plot ( + GT[samples,0,:], + label= f'GT {0}' + ) + plt.legend() + outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" + if IF_showfig==1: + plt.show() + else: + plt.savefig(outname, dpi=200) + plt.close () + +# # # ------------------------------------------- +# # #reverse y sequence +# # to_rev=result[:,0,:] +# # to_rev=to_rev.long().cpu().detach().numpy() +# # +# # y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) +# # +# # for iii in range (len(y_data_reversed)): +# # y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") +# # ++++++++++++++++++++++++++++++++++++++++++++ +# # extract the mask/seq_len from input if possible +# # here, from dataloader, we only use tokenized_data for mask generation +# result_mask = read_mask_from_input( +# tokenized_data=X_train_batch[:num_samples], +# mask_value=0, +# seq_data=None, +# max_seq_length=None, +# ) +# to_rev=result[:,0,:] # token (batch,seq_len) +# if CKeys['Debug_TrainerPack']==3: +# print("on foldable result: ", to_rev[0]) +# print("on result_logits: ", result_logits[0]) +# print("on mask: ", result_mask[0]) +# a = decode_one_ems_token_rec_for_folding_with_mask( +# to_rev[0], +# result_logits[0], +# pLM_alphabet, +# pLM_Model, +# result_mask[0], +# ) +# print('One resu: ', a) + +# y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( +# to_rev, +# result_logits, +# pLM_alphabet, +# pLM_Model, +# result_mask, +# ) +# if CKeys['Debug_TrainerPack']==3: +# print("on y_data_reversed[0]: ", y_data_reversed[0]) + +# # # ++++++++++++++++++++++++++++++++++++++++++++ +# # # reverse the PREDICTED y into a foldable sequence +# # # save this block for Model A +# # to_rev=result[:,0,:] # token (batch,seq_len) +# # y_data_reversed=decode_many_ems_token_rec_for_folding( +# # to_rev, +# # result_logits, +# # pLM_alphabet, +# # pLM_Model, +# # ) +# # if CKeys['Debug_TrainerPack']==3: +# # print("on foldable result: ", to_rev[0]) +# # print("on result_logits: ", result_logits[0]) +# # a = decode_one_ems_token_rec_for_folding( +# # to_rev[0], +# # result_logits[0], +# # pLM_alphabet, +# # pLM_Model, +# # ) +# # print('One resu: ', a) +# # print("on y_data_reversed: ", y_data_reversed[0]) + + +# # # ----------------------------------------------------- +# # #reverse GT_y sequence +# # to_rev=GT[:,0,:] +# # to_rev=to_rev.long().cpu().detach().numpy() +# # +# # GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) +# # +# # for iii in range (len(y_data_reversed)): +# # GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") +# # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# 
#reverse GT_y sequence +# # GT should be SAFE to reverse +# to_rev=GT[:,0,:] +# GT_y_data_reversed=decode_many_ems_token_rec( +# to_rev, +# pLM_alphabet, +# ) + + +# ### reverse second structure input.... +# to_rev=torch.round (X_train_batch[:,:]*Xnormfac) +# to_rev=to_rev.long().cpu().detach().numpy() +# # here, assume X_train_batch is for cond_img: there are padding at both beginning and ending part +# # so, first move the 0th padding to the end +# n_batch=to_rev.shape[0] +# n_embed=to_rev.shape[1] +# to_rev_1 = np.zeros(to_rev.shape) +# to_rev_1[:,0:n_embed-1]=to_rev[:,1:n_embed] + +# # ++ different input +# if tokenizer_X!=None: +# X_data_reversed=tokenizer_X.sequences_to_texts (to_rev_1) +# for iii in range (len(y_data_reversed)): +# X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") +# else: +# X_data_reversed=to_rev_1.copy() +# # + for debug +# if CKeys['Debug_TrainerPack']==1: +# print("X_data_reversed: ", X_data_reversed) + + + # print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples]) + print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} \nor\n {X_data_reversed[samples]}, ") + print (f"predicted sequence: {y_data_reversed[samples]}") + print (f"Ground truth: {GT_y_data_reversed[samples]}") + error=string_diff (y_data_reversed[samples], GT_y_data_reversed[samples])/len (GT_y_data_reversed[samples]) + print(f"Recovery ratio(Ref): {1.-error}") + + # move some + # # -- X_train_batch is normalized + # xbc=X_train_batch[samples,:].cpu().detach().numpy() + # # out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) + # out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) + # ++ + xbc = X_data_reversed[samples] + if type(xbc)==str: + out_nam_content=xbc + else: + out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc}) + # 1. write out the input X in the dataloder + out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt' + # + write the condition clearly + # X_data_reversed: an array + with open(out_nam_inX, "w") as inX_file: + # inX_file.write(f'{X_data_reversed[samples]}\n') + inX_file.write(out_nam_content) + # 2. write out the predictions + out_nam_OuY_PR=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}_predict.fasta' + with open(out_nam_OuY_PR, "w") as ouY_fasta: + ouY_fasta.write(f">Predicted\n") + ouY_fasta.write(y_data_reversed[samples]) + # 3. 
#    Only for dataloader: write out the recovered ground truth
+                out_nam_OuY_GT=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}_recGT.fasta'
+                with open(out_nam_OuY_GT, "w") as ouY_fasta:
+                    ouY_fasta.write(f">reconstructed GT, recoverability: {1.-error}\n")
+                    ouY_fasta.write(GT_y_data_reversed[samples])
+
+                if foldproteins:
+
+                    tempname='temp'
+                    pdb_file,fasta_file=foldandsavePDB_pdb_fasta (
+                        sequence=y_data_reversed[samples],
+                        filename_out=tempname,
+                        num_cycle=16, flag=flag,
+                        # +++++++++++++++++++
+                        prefix=prefix
+                    )
+
+                    # #out_nam=f'{prefix}{out_nam}.pdb'
+                    # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
+                    # ------------------------------------------------------
+                    # sometimes, the name below can get too long to fit
+                    # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
+                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
+                    # add a way to save the sampling name and results
+                    # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
+                    out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
+                    out_nam_seq=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.fasta'
+
+                    if CKeys['Debug_TrainerPack']==1:
+                        print("pdb_file: ", pdb_file)
+                        print("out_nam: ", out_nam)
+
+                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
+                    shutil.copy (pdb_file, out_nam) #source, dest
+                    shutil.copy (fasta_file, out_nam_seq)
+
+                    # clean the slate to avoid mistakenly reusing the previous fasta file
+                    os.remove (pdb_file)
+                    os.remove (fasta_file)
+
+                    pdb_file=out_nam
+                    print (f"Properly named PDB file produced: {pdb_file}")
+                    print (f"input X for sampling stored: {out_nam_inX}")
+
+                    if IF_showfig==1:
+                        view=show_pdb(
+                            pdb_file=pdb_file,
+                            flag=flag,
+                            show_sidechains=show_sidechains,
+                            show_mainchains=show_mainchains,
+                            color=color
+                        )
+                        view.show()
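+# ++
+# (added sketch) The recovery ratio printed above is 1 - string_diff(pred, gt)/len(gt).
+# A minimal helper, assuming string_diff(a, b) counts the differing positions
+# between two equal-length strings (as it is used throughout this file);
+# illustration only, not part of the original pipeline.
+def recovery_ratio_sketch(pred_seq, gt_seq):
+    # fraction of ground-truth positions recovered by the prediction
+    error = string_diff(pred_seq, gt_seq) / len(gt_seq)
+    return 1. - error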
+# ++
+# For ProteinPredictor
+def sample_loop_omegafold_pLM_ModelB_Predictor (
+    model,
+    train_loader,
+    cond_scales=[7.5], #list of cond scales - each sampled...
+    num_samples=2, #how many samples produced every time tested.....
+    timesteps=100, # not used
+    flag=0,
+    foldproteins=False,
+    use_text_embedd=True,
+    skip_steps=0,
+    # +++++++++++++++++++
+    train_unet_number=1,
+    ynormfac=1,
+    prefix=None,
+    tokenizer_y=None,
+    Xnormfac=1,
+    tokenizer_X=None,
+    # ++
+    CKeys=None,
+    sample_dir=None,
+    steps=None,
+    e=None,
+    IF_showfig=True, # effective only after foldproteins=True
+    # ++
+    pLM_Model=None,
+    pLM_Model_Name=None,
+    image_channels=None,
+    pLM_alphabet=None,
+    # ++
+    esm_layer=None,
+):
+    # =====================================================
+    # sample # = num_samples*(# of mini-batches)
+    # =====================================================
+    # steps=0
+    # e=flag
+    # for item in train_loader:
+    val_epoch_MSE_list=[]
+    resu_pred = {}
+    resu_grou = {}
+    #
+    for iisample in range (len (cond_scales)):
+        # calculate the loss for one selected cond_scale
+        #
+        val_epoch_MSE=0.
+        num_rec=0
+        this_prediction = []
+        this_groundtruth = []
+        #
+        for idx, item in enumerate(train_loader):
+
+            X_train_batch= item[0].to(device)
+            y_train_batch=item[1].to(device)
+
+            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+            # 1. adjust the number of samples to collect in each batch
+            if num_samples>y_train_batch.shape[0]:
+                print("Warning: sampling # > len(mini_batch)")
+            num_samples = min (num_samples,y_train_batch.shape[0])
+            print (f"Producing {num_samples} samples...")
+            X_train_batch_picked = X_train_batch[:num_samples,:] # X_train_batch_in[:num_samples ] #
+            GT=y_train_batch.cpu().detach()
+            GT_picked = GT[:num_samples,:]
+            # GT_picked = GT_picked.unsqueeze(1)
+
+            # 2. prepare if a pLM is used at the input end:
+            # this is done inside model.sample() via the x_data_tokenized channel
+            #
+            # 3. sample inside the loop over cond_scales
+            if use_text_embedd:
+                result_embedding=model.sample (
+                    # x= X_train_batch,
+                    x= X_train_batch_picked,
+                    stop_at_unet_number=train_unet_number ,
+                    cond_scale=cond_scales[iisample],
+                    device=device,
+                    skip_steps=skip_steps,
+                    # ++
+                    pLM_Model_Name=pLM_Model_Name,
+                    pLM_Model=pLM_Model,
+                    pLM_alphabet=pLM_alphabet,
+                    esm_layer=esm_layer,
+                )
+            else:
+                result_embedding=model.sample (
+                    x= None,
+                    # x_data_tokenized= X_train_batch,
+                    x_data_tokenized= X_train_batch_picked, # dim=(batch, seq_len), will extend channels inside .sample(),
+                    stop_at_unet_number=train_unet_number ,
+                    cond_scale=cond_scales[iisample],
+                    device=device,
+                    skip_steps=skip_steps,
+                    # ++
+                    pLM_Model_Name=pLM_Model_Name,
+                    pLM_Model=pLM_Model,
+                    pLM_alphabet=pLM_alphabet,
+                    esm_layer=esm_layer,
+                )
+            # result_embedding as image, dim: [batch, channels, seq_len]
+            #
+            # 4. translate the prediction into something meaningful:
+            # consider channel averaging and masking
+            # average across channels
+            result_embedding=torch.mean(result_embedding, 1) # (batch, seq_len)
+            # read the mask from the input: X_train_batch_picked (batch, seq_len)
+            # result_mask looks like 0,1,1,...,1,0,0
+            # the 0th component will be filled with zero
+            result_mask = read_mask_from_input(
+                tokenized_data=X_train_batch[:num_samples],
+                mask_value=0.0,
+                seq_data=None,
+                max_seq_length=None,
+            )
+            # apply the mask to the result: keep true and zero all false
+            result = result_embedding*result_mask # (batch, seq_len)
+            result = result.cpu()
+            # result = result.unsqueeze(1) # (batch, 1, seq_len)
+            # this is ONLY the raw result from the model, not a prediction yet
+            #
+            # 5. calculate the loss
+            with torch.no_grad():
+                val_loss_MSE = criterion_MSE_sum(
+                    result,
+                    GT_picked,
+                )
+            val_epoch_MSE += val_loss_MSE.item()/GT_picked.shape[1]
+            num_rec += len(GT_picked)
+            #
+            # 6. convert into a prediction
+            y_data_reversed = result*ynormfac
+            # prepare GT
+            GT_y_data_reversed = GT_picked*ynormfac
+            # accumulate the results
+            for ibat in range (GT_picked.shape[0]):
+                this_prediction.append (np.array( y_data_reversed[ibat,:].cpu() ))
+                this_groundtruth.append (np.array( GT_y_data_reversed[ibat,:].cpu() ))
+
+            #
+            # 7. reverse input to AA sequence... if needed
+            # TBA
+
+        # for one cond_scale:
+        # summarize the loss
+        TestSet_MSE = val_epoch_MSE/num_rec
+        resu_pred[str(cond_scales[iisample])] = this_prediction
+        resu_grou[str(cond_scales[iisample])] = this_groundtruth
+
+        # store the MSE along cond_scales
+        val_epoch_MSE_list.append(TestSet_MSE)
+
+    return val_epoch_MSE_list, resu_pred, resu_grou
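+# ++
+# (added sketch) Minimal illustration of the masked, per-position MSE computed in
+# the loop above: the channel-averaged embedding is zeroed outside the mask, the
+# summed squared error is normalized by seq_len per record and then by the number
+# of records. Assumes 0/1 masks of shape (batch, seq_len); illustration only,
+# not part of the original pipeline.
+def masked_mse_per_position_sketch(result_embedding, GT, mask):
+    # result_embedding, GT, mask: (batch, seq_len) tensors
+    result = result_embedding * mask        # zero out the padded positions
+    sse = criterion_MSE_sum(result, GT)     # summed squared error (reduction='sum')
+    return sse.item() / GT.shape[1] / GT.shape[0]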
+# ++++++++++++++++++++++++++++++++++++++++++++++++
+def sample_loop_omegafold_ModelB (
+    model,
+    train_loader,
+    cond_scales=[7.5], #list of cond scales - each sampled...
+    num_samples=2, #how many samples produced every time tested.....
+    timesteps=100, # not used
+    flag=0,
+    foldproteins=False,
+    use_text_embedd=True,
+    skip_steps=0,
+    # +++++++++++++++++++
+    train_unet_number=1,
+    ynormfac=1,
+    prefix=None,
+    tokenizer_y=None,
+    Xnormfac=1,
+    tokenizer_X=None,
+    # ++
+    CKeys=None,
+    sample_dir=None,
+    steps=None,
+    e=None,
+    IF_showfig=True, # effective only after foldproteins=True
+):
+    # =====================================================
+    # sample # = num_samples*(# of mini-batches)
+    # =====================================================
+    # steps=0
+    # e=flag
+    # for item in train_loader:
+    for idx, item in enumerate(train_loader):
+
+        X_train_batch= item[0].to(device)
+        y_train_batch=item[1].to(device)
+
+        GT=y_train_batch.cpu().detach()
+
+        GT= GT.unsqueeze(1)
+        if num_samples>y_train_batch.shape[0]:
+            print("Warning: sampling # > len(mini_batch)")
+
+        num_samples = min (num_samples,y_train_batch.shape[0] )
+        print (f"Producing {num_samples} samples...")
+        X_train_batch_picked = X_train_batch[:num_samples,:]
+        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)
+
+        for iisample in range (len (cond_scales)):
+
+            if use_text_embedd:
+                result=model.sample (
+                    # x= X_train_batch,
+                    x= X_train_batch_picked,
+                    stop_at_unet_number=train_unet_number ,
+                    cond_scale=cond_scales[iisample],
+                    device=device,
+                    skip_steps=skip_steps
+                )
+            else:
+                result=model.sample (
+                    x= None,
+                    # x_data_tokenized= X_train_batch,
+                    x_data_tokenized= X_train_batch_picked,
+                    stop_at_unet_number=train_unet_number ,
+                    cond_scale=cond_scales[iisample],
+                    device=device,
+                    skip_steps=skip_steps
+                )
+
+            result=torch.round(result*ynormfac)
+            GT=torch.round (GT*ynormfac)
+
+            for samples in range (num_samples):
+                print ("sample ", samples+1, "out of ", num_samples)
+
+                fig=plt.figure()
+                plt.plot (
+                    result[samples,0,:].cpu().detach().numpy(),
+                    label= f'Predicted'
+                )
+                plt.plot (
+                    GT[samples,0,:],
+                    label= f'GT {0}'
+                )
+                plt.legend()
+                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
+                if IF_showfig==1:
+                    plt.show()
+                else:
+                    plt.savefig(outname, dpi=200)
+                plt.close ()
+
+                #reverse y sequence
+                to_rev=result[:,0,:]
+                to_rev=to_rev.long().cpu().detach().numpy()
+
+                y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+                for iii in range (len(y_data_reversed)):
+                    y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
+
+                #reverse GT_y sequence
+                to_rev=GT[:,0,:]
+                to_rev=to_rev.long().cpu().detach().numpy()
+
+                GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+                for iii in range (len(y_data_reversed)):
+                    GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")
+
+                ### reverse the secondary structure input....
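+                # (added note) X was stored normalized by Xnormfac; multiplying it
+                # back and rounding recovers integer token ids that tokenizer_X can
+                # decode, while float (non-tokenized) inputs keep the rounded copy.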
+                to_rev=torch.round (X_train_batch[:,:]*Xnormfac)
+                to_rev=to_rev.long().cpu().detach().numpy()
+
+                # ++ different input
+                if tokenizer_X!=None:
+                    X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
+                    for iii in range (len(y_data_reversed)):
+                        X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
+                else:
+                    X_data_reversed=to_rev.copy()
+                # + for debug
+                if CKeys['Debug_TrainerPack']==1:
+                    print("X_data_reversed: ", X_data_reversed)
+
+                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples])
+                print (f"Ground truth: {GT_y_data_reversed[samples]}")
+
+                if foldproteins:
+                    xbc=X_train_batch[samples,:].cpu().detach().numpy()
+                    out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
+                    tempname='temp'
+                    pdb_file=foldandsavePDB (
+                        sequence=y_data_reversed[samples],
+                        filename_out=tempname,
+                        num_cycle=16, flag=flag,
+                        # +++++++++++++++++++
+                        prefix=prefix
+                    )
+
+                    # #out_nam=f'{prefix}{out_nam}.pdb'
+                    # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
+                    # ------------------------------------------------------
+                    # sometimes, the name below can get too long to fit
+                    # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
+                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
+                    # add a way to save the sampling name and results
+                    # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
+                    out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
+                    out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt'
+
+                    if CKeys['Debug_TrainerPack']==1:
+                        print("pdb_file: ", pdb_file)
+                        print("out_nam: ", out_nam)
+
+                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
+                    shutil.copy (pdb_file, out_nam) #source, dest
+                    #
+
+                    with open(out_nam_inX, "w") as inX_file:
+                        inX_file.write(f'{X_data_reversed[samples]}\n')
+
+                    pdb_file=out_nam
+                    print (f"Properly named PDB file produced: {pdb_file}")
+                    print (f"input X for sampling stored: {out_nam_inX}")
+
+                    if IF_showfig==1:
+                        view=show_pdb(
+                            pdb_file=pdb_file,
+                            flag=flag,
+                            show_sidechains=show_sidechains,
+                            show_mainchains=show_mainchains,
+                            color=color
+                        )
+                        view.show()
+
+# steps=steps+1
+
+# if steps>num_samples:
+#     break
+
+#
+#
+def sample_loop_FromModelB (model,
+    train_loader,
+    cond_scales=[7.5], #list of cond scales - each sampled...
+    num_samples=2, #how many samples produced every time tested.....
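+    # (added note) legacy loop kept for reference: it assumes keras-style
+    # tokenizers on both X (SecStr text) and y (AA sequence) and no pLM
+    # embedding; the pLM-aware variants above supersede it.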
+ timesteps=100, + flag=0,foldproteins=False, + use_text_embedd=True,skip_steps=0, + # +++++++++++++++++++ + train_unet_number=1, + ynormfac=1, + prefix=None, + tokenizer_y=None, + Xnormfac=1, + tokenizer_X=None, + + ): + steps=0 + e=flag + for item in train_loader: + + X_train_batch= item[0].to(device) + y_train_batch=item[1].to(device) + + GT=y_train_batch.cpu().detach() + + GT= GT.unsqueeze(1) + num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + print ('X_train_batch shape: ', X_train_batch.shape) + + for iisample in range (len (cond_scales)): + + if use_text_embedd: + result=model.sample (x= X_train_batch,stop_at_unet_number=train_unet_number , + cond_scale=cond_scales[iisample], device=device, skip_steps=skip_steps) + else: + result=model.sample (x= None, x_data_tokenized= X_train_batch, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales[iisample],device=device,skip_steps=skip_steps) + + result=torch.round(result*ynormfac) + GT=torch.round (GT*ynormfac) + + for samples in range (num_samples): + print ("sample ", samples, "out of ", num_samples) + + plt.plot (result[samples,0,:].cpu().detach().numpy(),label= f'Predicted') + plt.plot (GT[samples,0,:],label= f'GT {0}') + plt.legend() + + outname = prefix+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" + + plt.savefig(outname, dpi=200) + plt.show () + + #reverse y sequence + to_rev=result[:,0,:] + to_rev=to_rev.long().cpu().detach().numpy() + + y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") + + #reverse GT_y sequence + to_rev=GT[:,0,:] + to_rev=to_rev.long().cpu().detach().numpy() + + GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") + + ### reverse second structure input.... + to_rev=torch.round (X_train_batch[:,:]*Xnormfac) + to_rev=to_rev.long().cpu().detach().numpy() + + X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) + + for iii in range (len(y_data_reversed)): + X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") + + print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples]) + print (f"Ground truth: {GT_y_data_reversed[samples]}") + + if foldproteins: + xbc=X_train_batch[samples,:].cpu().detach().numpy() + out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) + tempname='temp' + pdb_file=foldandsavePDB ( + sequence=y_data_reversed[samples], + filename_out=tempname, + num_cycle=16, flag=flag, + # +++++++++++++++++++ + prefix=prefix + ) + + #out_nam=f'{prefix}{out_nam}.pdb' + out_nam=f'{prefix}{X_data_reversed[samples]}.pdb' + print (f'Original PDB: {pdb_file} OUT: {out_nam}') + shutil.copy (pdb_file, out_nam) #source, dest + pdb_file=out_nam + print (f"Properly named PDB file produced: {pdb_file}") + + view=show_pdb(pdb_file=pdb_file, flag=flag, show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color) + view.show() + + steps=steps+1 + if steps>num_samples: + break + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# train_loop tasks: +# 1. calculate loss for one batch +# 2. call sample loop +# 3. call sample sequence +# 4. 
#    print records and save model
+# ===============================================================
+# for ProteinDesigner_B
+# 1. expanded for the Probability case
+# ++
+cal_norm_prob = nn.Softmax(dim=2)
+
+def train_loop_Model_B (
+    model,
+    train_loader,
+    test_loader,
+    #
+    optimizer=None,
+    print_every=10,
+    epochs= 300,
+    start_ep=0,
+    start_step=0,
+    train_unet_number=1,
+    print_loss_every_steps=1000,
+    #
+    trainer=None,
+    plot_unscaled=False,
+    max_batch_size=4,
+    save_model=False,
+    cond_scales=[7.5], #list of cond scales - each sampled...
+    num_samples=2, #how many samples produced every time tested.....
+    foldproteins=False,
+    cond_image=False, #use cond_images...
+    # add some
+    # +++++++++++++++++++++++++++
+    device=None,
+    loss_list=[],
+    epoch_list=[],
+    train_hist_file=None,
+    train_hist_file_full=None,
+    prefix=None,
+    Xnormfac=1.,
+    ynormfac=1.,
+    tokenizer_X=None,
+    tokenizer_y=None,
+    test_condition_list=[],
+    max_length=1,
+    CKeys=None,
+    sample_steps=1,
+    sample_dir=None,
+    save_every_epoch=1,
+    save_point_info_file=None,
+    store_dir=None,
+    # ++
+    pLM_Model_Name=None,
+    image_channels=None,
+    print_error=False,
+):
+
+    if not exists (trainer):
+        if not exists (optimizer):
+            print ("ERROR: If a trainer is not used, an optimizer must be provided.")
+    if exists (trainer):
+        print ("Trainer provided... will be used")
+
+    # steps=start_step+1
+    # # +
+    # added_steps=0+1
+    #
+    steps=start_step
+    added_steps=0
+
+    loss_total=0
+
+    # ++ for pLM
+    if pLM_Model_Name=='None':
+        pLM_Model=None
+        # keep these names defined for the sampling calls below
+        esm_alphabet=None
+        esm_layer=None
+
+    elif pLM_Model_Name=='esm2_t33_650M_UR50D':
+        # dim: 1280
+        esm_layer=33
+        pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+        len_toks=len(esm_alphabet.all_toks)
+        pLM_Model.eval()
+        pLM_Model.to(device)
+
+    elif pLM_Model_Name=='esm2_t36_3B_UR50D':
+        # dim: 2560
+        esm_layer=36
+        pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
+        len_toks=len(esm_alphabet.all_toks)
+        pLM_Model.eval()
+        pLM_Model.to(device)
+
+    elif pLM_Model_Name=='esm2_t30_150M_UR50D':
+        # dim: 640
+        esm_layer=30
+        pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
+        len_toks=len(esm_alphabet.all_toks)
+        pLM_Model.eval()
+        pLM_Model.to(device)
+
+    elif pLM_Model_Name=='esm2_t12_35M_UR50D':
+        # dim: 480
+        esm_layer=12
+        pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
+        len_toks=len(esm_alphabet.all_toks)
+        pLM_Model.eval()
+        pLM_Model.to(device)
+
+    else:
+        print("The requested pLM model is not defined...")
+
+    for e in range(1, epochs+1):
+
+        # start = time.time()
+
+        torch.cuda.empty_cache()
+        print ("######################################################################################")
+        start = time.time()
+        print ("NOW: Training epoch: ", e+start_ep)
+
+        # TRAINING
+        train_epoch_loss = 0
+        model.train()
+
+        print ("Loop over ", len(train_loader), " batches (print '.' 
every ", print_every, " steps)") + + for item in train_loader: + steps += 1 + added_steps += 1 + + X_train_batch= item[0].to(device) + y_train_batch= item[1].to(device) + # project y_ into embedding space + if CKeys["Debug_TrainerPack"]==1: + print("Initial unload the dataloader items: ...") + print("X_train_batch.dim: ", X_train_batch.shape) + print("y_train_batch.dim: ", y_train_batch.shape) + + # --------------------------------------------------------- + # prepare for model.forward() to calculate loss + # --------------------------------------------------------- + # # -- + # if pLM_Model_Name=='None': + # # just use the encoded sequence + # y_train_batch_in = y_train_batch.unsqueeze(1) + # X_train_batch_in = X_train_batch.unsqueeze(1) + # # pass + # elif pLM_Model_Name=='esm2_t33_650M_UR50D': + # with torch.no_grad(): + # results = pLM_Model( + # y_train_batch, + # repr_layers=[33], + # return_contacts=False, + # ) + # y_train_batch_in = results["representations"][33] # (batch, seq_len, channels) + # y_train_batch_in = rearrange( + # y_train_batch_in, + # 'b l c -> b c l' + # ) + # X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) + # else: + # print(f"Required pLM name is not defined!!") + # ++ + if pLM_Model_Name=='None': + # just use the encoded sequence + y_train_batch_in = y_train_batch.unsqueeze(1) + X_train_batch_in = X_train_batch.unsqueeze(1) + # pass + else: # assume ESM models + with torch.no_grad(): + results = pLM_Model( + y_train_batch, + repr_layers=[esm_layer], + return_contacts=False, + ) + y_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels) + # ++ for Probability case + if image_channels==33: + with torch.no_grad(): + # calculate logits: (batch, seq_len, 33) + y_train_batch_in = pLM_Model.lm_head( + y_train_batch_in + ) + # normalize to get (0,1) probability + y_train_batch_in = cal_norm_prob(y_train_batch_in) + + # switch the dimension -> (batch, channel, seq_len) + y_train_batch_in = rearrange( + y_train_batch_in, + 'b l c -> b c l' + ) + + + X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) + + + # + for debug + if CKeys["Debug_TrainerPack"]==1: + print("After pLM model, the shape of X and y for training:") + print("X_train_batch_in.dim: ", X_train_batch_in.shape) + print("y_train_batch_in.dim: ", y_train_batch_in.shape) + + + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if exists (trainer): + + if cond_image==False: + loss = trainer( + y_train_batch.unsqueeze(1) , # true image (batch, channels, seq_len) + x=X_train_batch, # tokenized text (batch, ) + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + # # ---------------------------------------------------------- + # if cond_image==True: + # loss = trainer( + # y_train_batch.unsqueeze(1) , # true image + # x=None, # tokenized text + # cond_images=X_train_batch.unsqueeze(1), # cond_image + # unet_number=train_unet_number, + # max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + # ) + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if cond_image==True: + loss = trainer( + y_train_batch_in, # true image + x=None, # tokenized text + cond_images=X_train_batch_in, # cond_image + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and 
accumulate gradients, so it all fits in memory + ) + + trainer.update(unet_number = train_unet_number) + + else: + optimizer.zero_grad() + if cond_image==False: + loss=model ( + y_train_batch.unsqueeze(1) , + x=X_train_batch, + unet_number=train_unet_number + ) + # # ------------------------------------------------------ + # if cond_image==True: + # loss=model ( + # y_train_batch.unsqueeze(1) , + # x=None, + # cond_images=X_train_batch.unsqueeze(1), + # unet_number=train_unet_number + # ) + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if cond_image==True: + loss=model ( + y_train_batch_in , + x=None, + cond_images=X_train_batch_in, + unet_number=train_unet_number + ) + # + loss.backward( ) + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + + loss_total=loss_total+loss.item() + # + + train_epoch_loss=train_epoch_loss+loss.item() + + if steps % print_every == 0: + # for progress bar + print(".", end="") + + # if steps>0: + if added_steps>0: + + if steps % print_loss_every_steps == 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print('I am here') + print("Here is steps: ", steps) + + norm_loss=loss_total/print_loss_every_steps + print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # add a line to the hist file + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' + with open(train_hist_file,'a') as f: + f.write(add_line) + + + loss_list.append (norm_loss) + loss_total=0 + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + epoch_list.append(e+start_ep) + + # fig = plt.figure(figsize=(12,8),dpi=200) + fig = plt.figure() + plt.plot (epoch_list, loss_list, label='Loss') + plt.legend() + + # outname = prefix+ f"loss_{e+start_ep}_{steps}.jpg" + outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" + # + # the order, save then show, matters + if CKeys['SlientRun']==1: + plt.savefig(outname, dpi=200) + else: + plt.show() + plt.close(fig) + # plt.close() + + if added_steps>0: + # if steps>0: + if steps % sample_steps == 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print('I am here') + print("Here is steps: ", steps) + + if plot_unscaled: + #test before scaling... + plt.plot ( + y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), + label= 'Unscaled GT' + ) + plt.legend() + plt.show() + +# # -------------------------------------------------- +# GT=y_train_batch.cpu().detach() + +# GT=resize_image_to( +# GT.unsqueeze(1), +# model.imagen.image_sizes[train_unet_number-1], +# ) + + + #### + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + print ("I. SAMPLING IN TEST SET: ") + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + #### + # num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + + if cond_image == True: + use_text_embedd=False + # - + # cond_scales_extended=[1. for i in range(num_samples)] + # + + cond_scales_extended=cond_scales + else: + use_text_embedd=True + cond_scales_extended=cond_scales + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + sample_loop_omegafold_pLM_ModelB ( + model, + test_loader, + cond_scales=cond_scales_extended, # cond_scales,# #list of cond scales - each sampled... + num_samples=num_samples, #how many samples produced every time tested..... 
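+                    # (added note) cond_scales follows the classifier-free-guidance
+                    # convention used by Imagen-style samplers: 1.0 means plain
+                    # conditional sampling, larger values weight the conditioning
+                    # more strongly (the exact semantics live in model.sample)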
+ timesteps=64, + flag=steps, + #reverse=False, + foldproteins=foldproteins, + use_text_embedd= use_text_embedd, + # ++++++++++++++++++++ + train_unet_number=train_unet_number, + ynormfac=ynormfac, + prefix=prefix, + tokenizer_y=tokenizer_y, + Xnormfac=Xnormfac, + tokenizer_X=tokenizer_X, + # ++ + # ++ + CKeys=CKeys, + sample_dir=sample_dir, + steps=steps, + e=e+start_ep, + IF_showfig= CKeys['SlientRun']!=1, + # ++ + pLM_Model=pLM_Model, + pLM_Model_Name=pLM_Model_Name, + image_channels=image_channels, + pLM_alphabet=esm_alphabet, + ) + + + #--------------------------------------------------------- + # sample_loop ( + # model, + # test_loader, + # cond_scales=cond_scales,# #list of cond scales - each sampled... + # num_samples=num_samples, #how many samples produced every time tested..... + # timesteps=64, + # flag=steps, + # #reverse=False, + # foldproteins=foldproteins, + # use_text_embedd= use_text_embedd, + # # ++++++++++++++++++++ + # train_unet_number=train_unet_number, + # ynormfac=ynormfac, + # prefix=prefix, + # tokenizer_y=tokenizer_y, + # Xnormfac=Xnormfac, + # tokenizer_X=tokenizer_X, + # ) + + #index_word': '{"1": "~", "2": "h", "3": "e", "4": "s", "5": "t", "6": "g", "7": "b", "8": "i"}', + #'word_index': '{"~": 1, "h": 2, "e": 3, "s": 4, "t": 5, "g": 6, "b": 7, "i": 8}'} + + AH_code=2/Xnormfac + BS_code=3/Xnormfac + unstr_code= 1/Xnormfac + + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + print ("II. SAMPLING FOR DE NOVO:") + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + + # +++++++++++++++++++++++++++++++++++++++++ + DeNovoSam_pdbs, fasta_file_list=\ + sample_sequence_omegafold_pLM_ModelB ( + model, + x_data=test_condition_list, + flag=steps, # flag="DeNovo", # , + cond_scales=1., + foldproteins=foldproteins, + # ++++++++++ + ynormfac=ynormfac, + train_unet_number=train_unet_number, + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + prefix=prefix, + tokenizer_y=tokenizer_y, + # ++ + CKeys=CKeys, + sample_dir=sample_dir, + steps=steps, + e=e+start_ep, + IF_showfig= CKeys['SlientRun']!=1, + # ++ + pLM_Model=pLM_Model, + pLM_Model_Name=pLM_Model_Name, + image_channels=image_channels, + pLM_alphabet=esm_alphabet, + ) + + if print_error and len(DeNovoSam_pdbs)>0: + print("Calculate SecStr and design error:") + # + for ii in range(len(test_condition_list)): + seq=test_condition_list[ii][0] + DSSPresult,_,sequence_res=get_DSSP_result(DeNovoSam_pdbs[ii]) + print (f"INPUT: {seq}\nRESULT: {DSSPresult}\nAA sequence: {sequence_res}") + error=string_diff (DSSPresult, seq)/len (seq) + print ("Error: ", error) + + + + # # +++++++++++++++++++++++++++++++++++++++++ + # sample_sequence_omegafold_ModelB ( + # model, + # x_data=test_condition_list, + # flag=steps, # flag="DeNovo", # , + # cond_scales=1., + # foldproteins=foldproteins, + # # ++++++++++ + # ynormfac=ynormfac, + # train_unet_number=train_unet_number, + # tokenizer_X=tokenizer_X, + # Xnormfac=Xnormfac, + # max_length=max_length, + # prefix=prefix, + # tokenizer_y=tokenizer_y, + # # ++ + # CKeys=CKeys, + # sample_dir=sample_dir, + # steps=steps, + # e=e+start_ep, + # IF_showfig= CKeys['SlientRun']!=1, + # ) + + # for this_x_data in test_condition_list: + # sample_sequence_omegafold ( + # model, + # x_data=this_x_data, + # flag=steps, # flag="DeNovo", # , + # cond_scales=1., + # foldproteins=True, + # # ++++++++++ + # ynormfac=ynormfac, + # train_unet_number=train_unet_number, + # tokenizer_X=tokenizer_X, + # Xnormfac=Xnormfac, + # max_length=max_length, + # prefix=prefix, + # 
#        tokenizer_y=tokenizer_y,
+            #     # ++
+            #     CKeys=CKeys,
+            #     sample_dir=sample_dir,
+            #     steps=steps,
+            #     e=e+start_ep,
+            #     IF_showfig= CKeys['SlientRun']!=1,
+            # )
+            #
+            # model,
+            # X=None, #this is the target conventionally when using text embd
+            # flag=0,
+            # cond_scales=1.,
+            # foldproteins=False,
+            # X_string=None,
+            # x_data=None,
+            # skip_steps=0,
+            # inpaint_images = None,
+            # inpaint_masks = None,
+            # inpaint_resample_times = None,
+            # init_images = None,
+            # num_cycle=16,
+            # # ++++++++++++++++++++++++
+            # ynormfac=1,
+            # train_unet_number=1,
+            # tokenizer_X=None,
+            # Xnormfac=1.,
+            # max_length=1.,
+            # prefix=None,
+            # tokenizer_y=None,
+            # # ++
+            # CKeys=None,
+            # sample_dir=None,
+
+            # -----------------------------------------
+            # sample_sequence (
+            #     model,
+            #     x_data=['~~~HHHHHHHHHHHHHHH~~'],
+            #     flag=steps,cond_scales=1.,
+            #     foldproteins=True,
+            #     # ++++++++++
+            #     ynormfac=ynormfac,
+            # )
+            # sample_sequence (
+            #     model,
+            #     x_data=['~~~HHHHHHHHHHHHHHH~~~~HHHHHHHHHHHHHH~~~'],
+            #     flag=steps,cond_scales=1.,
+            #     foldproteins=True,
+            #     # ++++++++++
+            #     ynormfac=ynormfac,
+            # )
+            # sample_sequence (
+            #     model,
+            #     x_data=['~~EEESSTTS~SEEEEEEEEE~SBS~EEEEEE~~'],
+            #     flag=steps,cond_scales=1.,
+            #     foldproteins=True,
+            #     # ++++++++++++
+            #     ynormfac=ynormfac,
+            # )
+
+        # if steps>0:
+        # # --------------------------------------------------------------------
+        # if added_steps>0:
+        #     if save_model and steps % print_loss_every_steps==0:
+        #         fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt"
+        #         trainer.save(fname)
+        #         print (f"Model saved: ", fname)
+        #         fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt"
+        #         torch.save(model.state_dict(), fname)
+        #         print (f"Statedict model saved: ", fname)
+
+        # steps=steps+1
+        # added_steps += 1
+
+        # every epoch:
+        norm_loss_over_e = train_epoch_loss/len(train_loader)
+        print("\nnorm_loss over 1 epoch: ", norm_loss_over_e)
+        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
+        # write this into "train_hist_file_full"
+        add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
+        with open(train_hist_file_full,'a') as f:
+            f.write(add_line)
+
+        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+        # save the model every `save_every_epoch` epochs
+        if save_model and (e+start_ep) % save_every_epoch==0 and e>1:
+            # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt"
+            fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt"
+            trainer.save(fname)
+            print (f"Model saved: ", fname)
+            # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt"
+            fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt"
+            torch.save(model.state_dict(), fname)
+            print (f"Statedict model saved: ", fname)
+            # add a saving-point file
+            top_line='epoch,steps,norm_loss'+'\n'
+            # use the epoch-averaged loss here; norm_loss is only set at
+            # print_loss_every_steps boundaries and may be stale or undefined
+            add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
+            with open(save_point_info_file, "w") as f:
+                f.write(top_line)
+                f.write(add_line)
+
+        print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60} min\n-------------------")
+
+# ===============================================================
+# for ProteinPredictor_B
+#
+def train_loop_Model_B_Predictor (
+    model,
+    train_loader,
+    test_loader,
+    #
+    optimizer=None,
+    print_every=10,
+    epochs= 300,
+    start_ep=0,
+    start_step=0,
+    train_unet_number=1,
+    print_loss_every_steps=1000,
+    #
+    trainer=None,
+    plot_unscaled=False,
+    max_batch_size=4,
+    save_model=False,
+    cond_scales=[1.], #list of cond scales - each sampled...
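+    # (added note) Predictor variant: compared with train_loop_Model_B above, the
+    # pLM here embeds the input X (AA sequence) as the conditioning image, while
+    # the target y profile is broadcast across image_channels; pLM loading is
+    # delegated to load_in_pLM().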
+ num_samples=2, #how many samples produced every time tested..... + foldproteins=False, + cond_image=False, #use cond_images... + # add some + # +++++++++++++++++++++++++++ + device=None, + loss_list=[], + epoch_list=[], + train_hist_file=None, + train_hist_file_full=None, + prefix=None, + Xnormfac=1., + ynormfac=1., + tokenizer_X=None, + tokenizer_y=None, + test_condition_list=[], + max_length=1, + CKeys=None, + sample_steps=1, + sample_dir=None, + save_every_epoch=1, + save_point_info_file=None, + store_dir=None, + # ++ + pLM_Model_Name=None, + image_channels=None, + print_error=False, + # ++ + train_hist_file_on_testset=None, +): + + if not exists (trainer): + if not exists (optimizer): + print ("ERROR: If trainer not used, need to provide optimizer.") + if exists (trainer): + print ("Trainer provided... will be used") + + # steps=start_step+1 + # # + + # added_steps=0+1 + # + steps=start_step + added_steps=0 + + loss_total=0 + + # ++ for pLM +# # -- +# if pLM_Model_Name=='trivial': +# pLM_Model=None + +# elif pLM_Model_Name=='esm2_t33_650M_UR50D': +# # dim: 1280 +# esm_layer=33 +# pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() +# len_toks=len(esm_alphabet.all_toks) +# pLM_Model.eval() +# pLM_Model. to(device) + +# elif pLM_Model_Name=='esm2_t36_3B_UR50D': +# # dim: 2560 +# esm_layer=36 +# pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() +# len_toks=len(esm_alphabet.all_toks) +# pLM_Model.eval() +# pLM_Model. to(device) + +# elif pLM_Model_Name=='esm2_t30_150M_UR50D': +# # dim: 640 +# esm_layer=30 +# pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() +# len_toks=len(esm_alphabet.all_toks) +# pLM_Model.eval() +# pLM_Model. to(device) + +# elif pLM_Model_Name=='esm2_t12_35M_UR50D': +# # dim: 480 +# esm_layer=12 +# pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() +# len_toks=len(esm_alphabet.all_toks) +# pLM_Model.eval() +# pLM_Model. to(device) + +# else: +# print("pLM model is missing...") + # ++ + pLM_Model, esm_alphabet, \ + esm_layer, len_toks = load_in_pLM( + pLM_Model_Name, + device, + ) + + + for e in range(1, epochs+1): + + # start = time.time() + + torch.cuda.empty_cache() + print ("######################################################################################") + start = time.time() + print ("NOW: Training epoch: ", e+start_ep) + + # TRAINING + train_epoch_loss = 0 + model.train() + + print ("Loop over ", len(train_loader), " batches (print . 
every ", print_every, " steps)") + + for item in train_loader: + steps += 1 + added_steps += 1 + + X_train_batch= item[0].to(device) + y_train_batch= item[1].to(device) + # project y_ into embedding space + if CKeys["Debug_TrainerPack"]==1: + print("Initial unload the dataloader items: ...") + print("X_train_batch.dim: ", X_train_batch.shape) + print("y_train_batch.dim: ", y_train_batch.shape) + + # --------------------------------------------------------- + # prepare for model.forward() to calculate loss + # --------------------------------------------------------- + # # -- + # if pLM_Model_Name=='None': + # # just use the encoded sequence + # y_train_batch_in = y_train_batch.unsqueeze(1) + # X_train_batch_in = X_train_batch.unsqueeze(1) + # # pass + # elif pLM_Model_Name=='esm2_t33_650M_UR50D': + # with torch.no_grad(): + # results = pLM_Model( + # y_train_batch, + # repr_layers=[33], + # return_contacts=False, + # ) + # y_train_batch_in = results["representations"][33] # (batch, seq_len, channels) + # y_train_batch_in = rearrange( + # y_train_batch_in, + # 'b l c -> b c l' + # ) + # X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) + # else: + # print(f"Required pLM name is not defined!!") + # ++ + if pLM_Model_Name=='trivial': + # just use the encoded sequence + y_train_batch_in = y_train_batch.unsqueeze(1) + X_train_batch_in = X_train_batch.unsqueeze(1) + # pass + else: + # assume ESM models + # -- + # # for ProteinDesigner + # with torch.no_grad(): + # results = pLM_Model( + # y_train_batch, + # repr_layers=[esm_layer], + # return_contacts=False, + # ) + # y_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels) + # y_train_batch_in = rearrange( + # y_train_batch_in, + # 'b l c -> b c l' + # ) + # X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) + # + # ++ + # for ProteinPredictor + with torch.no_grad(): + results = pLM_Model( + X_train_batch, + repr_layers=[esm_layer], + return_contacts=False, + ) + X_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels) + X_train_batch_in = rearrange( + X_train_batch_in, + 'b l c -> b c l' + ) + y_train_batch_in = y_train_batch.unsqueeze(1).repeat(1,image_channels,1) + + + + # + for debug + if CKeys["Debug_TrainerPack"]==1: + print("After pLM model, the shape of X and y for training:") + print("X_train_batch_in.dim: ", X_train_batch_in.shape) + print("y_train_batch_in.dim: ", y_train_batch_in.shape) + + + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if exists (trainer): + + if cond_image==False: + loss = trainer( + y_train_batch.unsqueeze(1) , # true image (batch, channels, seq_len) + x=X_train_batch, # tokenized text (batch, ) + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + # # ---------------------------------------------------------- + # if cond_image==True: + # loss = trainer( + # y_train_batch.unsqueeze(1) , # true image + # x=None, # tokenized text + # cond_images=X_train_batch.unsqueeze(1), # cond_image + # unet_number=train_unet_number, + # max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + # ) + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if cond_image==True: + loss = trainer( + y_train_batch_in, # true image + x=None, # tokenized text + cond_images=X_train_batch_in, # cond_image + 
unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + + trainer.update(unet_number = train_unet_number) + + else: + optimizer.zero_grad() + if cond_image==False: + loss=model ( + y_train_batch.unsqueeze(1) , + x=X_train_batch, + unet_number=train_unet_number + ) + # # ------------------------------------------------------ + # if cond_image==True: + # loss=model ( + # y_train_batch.unsqueeze(1) , + # x=None, + # cond_images=X_train_batch.unsqueeze(1), + # unet_number=train_unet_number + # ) + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if cond_image==True: + loss=model ( + y_train_batch_in , + x=None, + cond_images=X_train_batch_in, + unet_number=train_unet_number + ) + # + loss.backward( ) + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + + loss_total=loss_total+loss.item() + # + + train_epoch_loss=train_epoch_loss+loss.item() + + if steps % print_every == 0: + # for progress bar + print(".", end="") + + # if steps>0: + if added_steps>0: + + if steps % print_loss_every_steps == 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print('I am here') + print("Here is steps: ", steps) + + norm_loss=loss_total/print_loss_every_steps + print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # add a line to the hist file + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' + with open(train_hist_file,'a') as f: + f.write(add_line) + + + loss_list.append (norm_loss) + loss_total=0 + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + epoch_list.append(e+start_ep) + + # fig = plt.figure(figsize=(12,8),dpi=200) + fig = plt.figure() + plt.plot (epoch_list, loss_list, label='Loss') + plt.legend() + + # outname = prefix+ f"loss_{e+start_ep}_{steps}.jpg" + outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" + # + # the order, save then show, matters + if CKeys['SlientRun']==1: + plt.savefig(outname, dpi=200) + else: + plt.show() + plt.close(fig) + # plt.close() + + if added_steps>0: + # if steps>0: + if steps % sample_steps == 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print('I am here') + print("Here is steps: ", steps) + + if plot_unscaled: + #test before scaling... + plt.plot ( + y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), + label= 'Unscaled GT' + ) + plt.legend() + plt.show() + +# # -------------------------------------------------- +# GT=y_train_batch.cpu().detach() + +# GT=resize_image_to( +# GT.unsqueeze(1), +# model.imagen.image_sizes[train_unet_number-1], +# ) + + + #### + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + print ("I. SAMPLING IN TEST SET: ") + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + #### + # num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + + if cond_image == True: + use_text_embedd=False + # - + # cond_scales_extended=[1. for i in range(num_samples)] + # + + cond_scales_extended=cond_scales + else: + use_text_embedd=True + cond_scales_extended=cond_scales + + # # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # # For ProteinDesigner + # sample_loop_omegafold_pLM_ModelB ( + # model, + # test_loader, + # cond_scales=cond_scales_extended, # cond_scales,# #list of cond scales - each sampled... 
+ # num_samples=num_samples, #how many samples produced every time tested..... + # timesteps=64, + # flag=steps, + # #reverse=False, + # foldproteins=foldproteins, + # use_text_embedd= use_text_embedd, + # # ++++++++++++++++++++ + # train_unet_number=train_unet_number, + # ynormfac=ynormfac, + # prefix=prefix, + # tokenizer_y=tokenizer_y, + # Xnormfac=Xnormfac, + # tokenizer_X=tokenizer_X, + # # ++ + # # ++ + # CKeys=CKeys, + # sample_dir=sample_dir, + # steps=steps, + # e=e+start_ep, + # IF_showfig= CKeys['SlientRun']!=1, + # # ++ + # pLM_Model=pLM_Model, + # pLM_Model_Name=pLM_Model_Name, + # image_channels=image_channels, + # pLM_alphabet=esm_alphabet, + # ) + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # For ProteinPredictor: + val_epoch_MSE_list, \ + resu_pred, resu_grou = \ + sample_loop_omegafold_pLM_ModelB_Predictor ( + model, + test_loader, + cond_scales=[1.], # cond_scales_extended, # #list of cond scales - each sampled... + num_samples=num_samples, #how many samples produced every time tested..... + timesteps=64, + flag=steps, + #reverse=False, + foldproteins=foldproteins, + use_text_embedd= use_text_embedd, + # ++++++++++++++++++++ + train_unet_number=train_unet_number, + ynormfac=ynormfac, + prefix=prefix, + tokenizer_y=tokenizer_y, + Xnormfac=Xnormfac, + tokenizer_X=tokenizer_X, + # ++ + # ++ + CKeys=CKeys, + sample_dir=sample_dir, + steps=steps, + e=e+start_ep, + IF_showfig= CKeys['SlientRun']!=1, + # ++ + pLM_Model=pLM_Model, + pLM_Model_Name=pLM_Model_Name, + image_channels=image_channels, + pLM_alphabet=esm_alphabet, + # ++ + esm_layer=esm_layer, + ) + # record the ERROR on the test set + print(f"Epo {str(e+start_ep)}, on TestSet, MSE: {val_epoch_MSE_list[0]}") + # only write the 0th case of MSE + add_line = str(e+start_ep)+','+str(steps)+','+str(val_epoch_MSE_list[0])+'\n' + with open(train_hist_file_on_testset,'a') as f: + f.write(add_line) + + + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + print ("II. 
SAMPLING FOR DE NOVO: NOT USED in predictor mode") + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + +# # +++++++++++++++++++++++++++++++++++++++++ +# DeNovoSam_pdbs, fasta_file_list=\ +# sample_sequence_omegafold_pLM_ModelB ( +# model, +# x_data=test_condition_list, +# flag=steps, # flag="DeNovo", # , +# cond_scales=1., +# foldproteins=foldproteins, +# # ++++++++++ +# ynormfac=ynormfac, +# train_unet_number=train_unet_number, +# tokenizer_X=tokenizer_X, +# Xnormfac=Xnormfac, +# max_length=max_length, +# prefix=prefix, +# tokenizer_y=tokenizer_y, +# # ++ +# CKeys=CKeys, +# sample_dir=sample_dir, +# steps=steps, +# e=e+start_ep, +# IF_showfig= CKeys['SlientRun']!=1, +# # ++ +# pLM_Model=pLM_Model, +# pLM_Model_Name=pLM_Model_Name, +# image_channels=image_channels, +# pLM_alphabet=esm_alphabet, +# ) + +# if print_error and len(DeNovoSam_pdbs)>0: +# print("Calculate SecStr and design error:") +# # +# for ii in range(len(test_condition_list)): +# seq=test_condition_list[ii][0] +# DSSPresult,_,sequence_res=get_DSSP_result(DeNovoSam_pdbs[ii]) +# print (f"INPUT: {seq}\nRESULT: {DSSPresult}\nAA sequence: {sequence_res}") +# error=string_diff (DSSPresult, seq)/len (seq) +# print ("Error: ", error) + + + + # every epoch: + norm_loss_over_e = train_epoch_loss/len(train_loader) + print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # write this into "train_hist_file_full" + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n' + with open(train_hist_file_full,'a') as f: + f.write(add_line) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # save model every this epoches + if save_model and (e+start_ep) % save_every_epoch==0 and e>1: + # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt" + fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt" + trainer.save(fname) + print (f"Model saved: ", fname) + # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt" + fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt" + torch.save(model.state_dict(), fname) + print (f"Statedict model saved: ", fname) + # add a saving point file + top_line='epoch,steps,norm_loss'+'\n' + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' + with open(save_point_info_file, "w") as f: + f.write(top_line) + f.write(add_line) + + + + print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------") + + +def train_loop_Old_FromModelB (model, + train_loader, + test_loader, + optimizer=None, + print_every=10, + epochs= 300, + start_ep=0, + start_step=0, + train_unet_number=1, + print_loss=1000, + trainer=None, + plot_unscaled=False, + max_batch_size=4, + save_model=False, + cond_scales=[7.5], #list of cond scales - each sampled... + num_samples=2, #how many samples produced every time tested..... + foldproteins=False, + cond_image=False, #use cond_images... + # add some + # +++++++++++++++++++++++++++ + device=None, + loss_list=[], + prefix=None, + ynormfac=1, + test_condition_list=[], + tokenizer_y=None, + Xnormfac=1, + tokenizer_X=None, + max_length=1, + + ): + + if not exists (trainer): + if not exists (optimizer): + print ("ERROR: If trainer not used, need to provide optimizer.") + if exists (trainer): + print ("Trainer provided... 
will be used") + + steps=start_step + + loss_total=0 + for e in range(1, epochs+1): + + start = time.time() + + torch.cuda.empty_cache() + print ("######################################################################################") + start = time.time() + print ("NOW: Training epoch: ", e+start_ep) + + train_epoch_loss = 0 + model.train() + + print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") + + for item in train_loader: + + X_train_batch= item[0].to(device) + + y_train_batch=item[1].to(device) + + if exists (trainer): + if cond_image==False: + loss = trainer( + y_train_batch.unsqueeze(1) , + x=X_train_batch, + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + if cond_image==True: + + loss = trainer( + y_train_batch.unsqueeze(1) ,x=None, + cond_images=X_train_batch.unsqueeze(1), + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + trainer.update(unet_number = train_unet_number) + + else: + optimizer.zero_grad() + if cond_image==False: + loss=model (y_train_batch.unsqueeze(1) , x=X_train_batch, unet_number=train_unet_number) + if cond_image==True: + loss=model (y_train_batch.unsqueeze(1) ,x=None, cond_images=X_train_batch.unsqueeze(1), unet_number=train_unet_number) + + loss.backward( ) + + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + + optimizer.step() + + loss_total=loss_total+loss.item() + + if steps % print_every == 0: + print(".", end="") + + if steps>0: + if steps % print_loss == 0: + + if plot_unscaled: + #test before scaling... + plt.plot (y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),label= 'Unscaled GT') + plt.legend() + plt.show() + + + GT=y_train_batch.cpu().detach() + + GT=resize_image_to( + GT.unsqueeze(1), + model.imagen.image_sizes[train_unet_number-1], + + ) + + norm_loss=loss_total/print_loss + print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}") + + loss_list.append (norm_loss) + loss_total=0 + + plt.plot (loss_list, label='Loss') + plt.legend() + + outname = prefix+ f"loss_{e}_{steps}.jpg" + plt.savefig(outname, dpi=200) + plt.show() + + #### + num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + + if cond_image == True: + use_text_embedd=False + else: + use_text_embedd=True + + sample_loop_FromModelB ( + model, + test_loader, + cond_scales=cond_scales,# #list of cond scales - each sampled... + num_samples=num_samples, #how many samples produced every time tested..... 
+ timesteps=64, + flag=steps, + #reverse=False, + foldproteins=foldproteins, + use_text_embedd= use_text_embedd, + # ++++++++++++++++++++ + train_unet_number=train_unet_number, + ynormfac=ynormfac, + prefix=prefix, + tokenizer_y=tokenizer_y, + Xnormfac=Xnormfac, + tokenizer_X=tokenizer_X, + ) + + #index_word': '{"1": "~", "2": "h", "3": "e", "4": "s", "5": "t", "6": "g", "7": "b", "8": "i"}', + #'word_index': '{"~": 1, "h": 2, "e": 3, "s": 4, "t": 5, "g": 6, "b": 7, "i": 8}'} + + AH_code=2/Xnormfac + BS_code=3/Xnormfac + unstr_code= 1/Xnormfac + + print ("SAMPLING FOR DE NOVO:") + + # +++++++++++++++++++++++++++++++++++++++++ + for this_x_data in test_condition_list: + sample_sequence_FromModelB ( + model, + x_data=this_x_data, + flag=steps,cond_scales=1., + foldproteins=True, + # ++++++++++ + ynormfac=ynormfac, + train_unet_number=train_unet_number, + tokenizer_X=tokenizer_X, + Xnormfac=Xnormfac, + max_length=max_length, + prefix=prefix, + tokenizer_y=tokenizer_y, + ) + # ----------------------------------------- + # sample_sequence ( + # model, + # x_data=['~~~HHHHHHHHHHHHHHH~~'], + # flag=steps,cond_scales=1., + # foldproteins=True, + # # ++++++++++ + # ynormfac=ynormfac, + # ) + # sample_sequence ( + # model, + # x_data=['~~~HHHHHHHHHHHHHHH~~~~HHHHHHHHHHHHHH~~~'], + # flag=steps,cond_scales=1., + # foldproteins=True, + # # ++++++++++ + # ynormfac=ynormfac, + # ) + # sample_sequence ( + # model, + # x_data=['~~EEESSTTS~SEEEEEEEEE~SBS~EEEEEE~~'], + # flag=steps,cond_scales=1., + # foldproteins=True, + # # ++++++++++++ + # ynormfac=ynormfac, + # ) + + if steps>0: + if save_model and steps % print_loss==0: + fname=f"{prefix}trainer_save-model-epoch_{e}.pt" + trainer.save(fname) + print (f"Model saved: ", fname) + fname=f"{prefix}statedict_save-model-epoch_{e}.pt" + torch.save(model.state_dict(), fname) + print (f"Statedict model saved: ", fname) + + steps=steps+1 + + print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------") + + +# ++++++++++++++++++++++++++++++++++++++++++++++ +def foldandsavePDB_pdb_fasta ( + sequence, + filename_out, + num_cycle=16, + flag=0, + # ++++++++++++ + prefix=None, +): + + filename=f"{prefix}fasta_in_{flag}.fasta" + print ("Writing FASTA file: ", filename) + OUTFILE=f"{filename_out}_{flag}" + with open (filename, mode ='w') as f: + f.write (f'>{OUTFILE}\n') + f.write (f'{sequence}') + + print (f"Now run OmegaFold.... on device={device}") + # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device + cmd_line=F"omegafold {filename} {prefix} --num_cycle {num_cycle} --device={device}" + print(os.popen(cmd_line).read()) + + print ("Done OmegaFold") + + # PDB_result=f"{prefix}{OUTFILE}.PDB" + PDB_result=f"{prefix}{OUTFILE}.pdb" + print (f"Resulting PDB file...: {PDB_result}") + + return PDB_result, filename + + + +def foldandsavePDB ( + sequence, + filename_out, + num_cycle=16, + flag=0, + # ++++++++++++ + prefix=None, +): + + filename=f"{prefix}fasta_in_{flag}.fasta" + print ("Writing FASTA file: ", filename) + OUTFILE=f"{filename_out}_{flag}" + with open (filename, mode ='w') as f: + f.write (f'>{OUTFILE}\n') + f.write (f'{sequence}') + + print (f"Now run OmegaFold.... 
on device={device}")
+    # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device
+    cmd_line=f"omegafold {filename} {prefix} --num_cycle {num_cycle} --device={device}"
+    print(os.popen(cmd_line).read())
+
+    print ("Done OmegaFold")
+
+    # PDB_result=f"{prefix}{OUTFILE}.PDB"
+    PDB_result=f"{prefix}{OUTFILE}.pdb"
+    print (f"Resulting PDB file...: {PDB_result}")
+
+    return PDB_result
+
+import py3Dmol
+# Biopython's parser is needed by get_avg_Bfac() below; imported here in case
+# it is not imported elsewhere in this module.
+from Bio.PDB import PDBParser
+
+def plot_plddt_legend(dpi=100):
+    thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
+    plt.figure(figsize=(1,0.1),dpi=dpi)
+    ########################################
+    for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]:
+        plt.bar(0, 0, color=c)
+    plt.legend(thresh, frameon=False,
+               loc='center', ncol=6,
+               handletextpad=1,
+               columnspacing=1,
+               markerscale=0.5,)
+    plt.axis(False)
+    return plt
+
+color = "lDDT" # choose from ["chain", "lDDT", "rainbow"]
+show_sidechains = False # choose from {type:"boolean"}
+show_mainchains = False # choose from {type:"boolean"}
+
+def show_pdb(pdb_file, flag=0, show_sidechains=False, show_mainchains=False, color="lDDT"):
+    model_name = f"Flag_{flag}"
+    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
+    view.addModel(open(pdb_file,'r').read(),'pdb')
+
+    if color == "lDDT":
+        view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
+    elif color == "rainbow":
+        view.setStyle({'cartoon': {'color':'spectrum'}})
+    elif color == "chain":
+        # NOTE: this branch assumes ColabFold-style globals `queries` and
+        # `is_complex`, which are not defined in this module; only the
+        # "lDDT" and "rainbow" color modes are usable as-is.
+        chains = len(queries[0][1]) + 1 if is_complex else 1
+        for n,chain,color in zip(range(chains),list("ABCDEFGH"),
+                       ["lime","cyan","magenta","yellow","salmon","white","blue","orange"]):
+            view.setStyle({'chain':chain},{'cartoon': {'color':color}})
+    if show_sidechains:
+        BB = ['C','O','N']
+        view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
+                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
+        view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
+                      {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
+        view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
+                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
+    if show_mainchains:
+        BB = ['C','O','N','CA']
+        view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
+
+    view.zoomTo()
+    if color == "lDDT":
+        plot_plddt_legend().show()
+    return view
+
+def get_avg_Bfac (file='./output_v3/[0.0 0.5 0.0 0.0 0.0 0.0 0.0 0.0].pdb'):
+    # OmegaFold writes per-atom pLDDT into the B-factor column, so the
+    # average B-factor here is an average confidence score.
+    p = PDBParser()
+    avg_B=0
+    bfac_list=[]
+
+    structure = p.get_structure("X", file)
+    for PDBmodel in structure:
+        for chain in PDBmodel:
+            for residue in chain:
+                for atom in residue:
+                    Bfac=atom.get_bfactor()
+                    bfac_list.append(Bfac)
+                    avg_B=avg_B+Bfac
+
+    avg_B=avg_B/len (bfac_list)
+    print (f"For {file}, average B-factor={avg_B}")
+    plt.plot (bfac_list, label='lDDT')
+    plt.xlabel ('Atom #' )
+    plt.ylabel ('lDDT')
+    plt.legend()
+    plt.show()
+    return avg_B, bfac_list
+
+def sample_sequence_normalized_Bfac (seccs=[0.3, 0.3, 0.1, 0., 0., 0., 0., 0.
]): + sample_numbers=torch.tensor([seccs]) + sample_numbers=torch.nn.functional.normalize (sample_numbers, dim=1) + sample_numbers=sample_numbers/torch.sum(sample_numbers) + + PDB=sample_sequence (model, + X=sample_numbers, + flag=0,cond_scales=1, foldproteins=True, + ) + + avg,_ = get_avg_Bfac (file=PDB[0]) + + return PDB, avg + +# ====================================================== +# blocks for Model A +# ====================================================== +def train_loop_Old_FromModelA ( + model, + train_loader, + test_loader, + # + optimizer=None, + print_every=1, + epochs= 300, + start_ep=0, + start_step=0, + train_unet_number=1, + print_loss_every_steps=1000, + # + trainer=None, + plot_unscaled=False, + max_batch_size=4, + save_model=False, + cond_scales=[1.0], #list of cond scales + num_samples=2, #how many samples produced every time tested..... + foldproteins=False, + # ++ + cond_image=False, # not use cond_images... for model A + cond_text=True, # use condi_text... for model A + # + + device=None, + loss_list=[], + epoch_list=[], + train_hist_file=None, + train_hist_file_full=None, + prefix=None, # not used in this function + Xnormfac=None, + ynormfac=1., + tokenizer_X=None, + tokenizer_y=None, + test_condition_list=[], + max_length_Y=1, + max_text_len_X=1, + CKeys=None, + sample_steps=1, + sample_dir=None, + save_every_epoch=1, + save_point_info_file=None, + store_dir=None, +): + # #+ + # Xnormfac=Xnormfac.to(model.device) + + if not exists (trainer): + if not exists (optimizer): + print ("ERROR: If trainer not used, need to provide optimizer.") + if exists (trainer): + print ("Trainer provided... will be used") + # -------------------------------- + # steps=start_step + # ++++++++++++++++++++++++++++++++ + steps=start_step + added_steps=0 + + loss_total=0 + for e in range(1, epochs+1): + # start = time.time() + + torch.cuda.empty_cache() + print ("######################################################################################") + start = time.time() + print ("NOW: Training epoch: ", e+start_ep) + + # TRAINING + train_epoch_loss = 0 + model.train() + + print ("Loop over ", len(train_loader), " batches (print . 
every ", print_every, " steps)") + + for item in train_loader: + # ++ + steps += 1 + added_steps += 1 + + X_train_batch= item[0].to(device) + y_train_batch=item[1].to(device) + + if exists (trainer): + if cond_image==False: + # ======================================== + # Model A: condition via text + # ======================================== + # this block depends on the model:forward + loss = trainer( + # # -------------------------------- + # X_train_batch, + # y_train_batch.unsqueeze(1) , + # ++++++++++++++++++++++++++++++++ + y_train_batch.unsqueeze(1) , + x=X_train_batch, + # + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + if cond_image==True: + # ======================================== + # Model B: condition via image/sequence + # ======================================== + # added for future: Train_loop B + loss = trainer( + y_train_batch.unsqueeze(1) , + x=None, + cond_images=X_train_batch.unsqueeze(1), + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + # pass + # + trainer.update(unet_number = train_unet_number) + + else: + optimizer.zero_grad() + if cond_image==False: + # this block depends on the model:forward + loss=model ( + # # -------------------------------- + # X_train_batch, + # y_train_batch.unsqueeze(1) , + # ++++++++++++++++++++++++++++++++ + y_train_batch.unsqueeze(1) , + x=X_train_batch, + # + unet_number=train_unet_number + ) + if cond_image==True: + # added for future: Train_loop B + loss=model ( + y_train_batch.unsqueeze(1) , + x=None, + cond_images=X_train_batch.unsqueeze(1), + unet_number=train_unet_number + ) + # + loss.backward( ) + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + + loss_total=loss_total+loss.item() + # + + train_epoch_loss=train_epoch_loss+loss.item() + + if steps % print_every == 0: + # for progress bar + print(".", end="") + + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # record loss block + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # if steps>0: + if added_steps>0: + + if steps % print_loss_every_steps == 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print("Here is step: ", steps) + + norm_loss=loss_total/print_loss_every_steps + print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # add a line to the hist file + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' + with open(train_hist_file,'a') as f: + f.write(add_line) + + loss_list.append (norm_loss) + loss_total=0 + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + epoch_list.append(e+start_ep) + + fig = plt.figure() + plt.plot (epoch_list, loss_list, label='Loss') + plt.legend() + # outname = prefix+ f"loss_{e}_{steps}.jpg" + outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" + # + # the order, save then show, matters + if CKeys['SlientRun']==1: + plt.savefig(outname, dpi=200) + else: + plt.show() + plt.close(fig) + + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # sample in test set block + # set sample_steps < 0 to switch off this block + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # if steps>0: + if added_steps>0: + if steps % 
sample_steps == 0 and sample_steps > 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print("Here is steps: ", steps) + + if plot_unscaled: + # test before scaling... + plt.plot ( + y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), + label= 'Unscaled GT' + ) + plt.legend() + plt.show() + + #rescale GT to properly plot + GT=y_train_batch.cpu().detach() + + GT=resize_image_to( + GT.unsqueeze(1), + model.imagen.image_sizes[train_unet_number-1], + + ) + #### + print ("I. SAMPLING IN TEST SET: ") + #### + + num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + sample_loop_omegafold_ModelA ( + model, + test_loader, + cond_scales=cond_scales, + num_samples=num_samples, #how many samples produced every time tested..... + timesteps=None, + flag=e+start_ep, # steps, + foldproteins=foldproteins, + # add condi_key + cond_image=cond_image, # Not used for now + cond_text=cond_text, # Not used for now + skip_steps=0, + # + max_text_len=max_text_len_X, + max_length=max_length_Y, + # ++++++++++++++++++++ + train_unet_number=train_unet_number, + ynormfac=ynormfac, + prefix=prefix, # + tokenizer_y=tokenizer_y, + Xnormfac_CondiText=Xnormfac, + tokenizer_X_CondiText=tokenizer_X, + # ++ + CKeys=CKeys, + sample_dir=sample_dir, + steps=steps, + e=e+start_ep, + IF_showfig= CKeys['SlientRun']!=1 , + ) + + print ("II. SAMPLING FOR DE NOVO:") + + sample_sequence_omegafold_ModelA ( + # # ---------------------------------------------- + # model, + # X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]], + # foldproteins=foldproteins, + # flag=steps,cond_scales=1., + # ++++++++++++++++++++++++++++++++++++++++++++++ + model, + X=test_condition_list, # [[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X + flag=e+start_ep, # steps, # 0, + cond_scales=cond_scales, # 1., + foldproteins=True, # False, + X_string=None, # from text conditioning X_string + x_data=None, # from image conditioning x_data + skip_steps=0, + inpaint_images=None, # in formation Y data + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, # for omegafolding + calc_error=False, # for check on folded results, not used for every case + # ++++++++++++++++++++++++++ + # tokenizers + tokenizer_X_forImageCondi=None, # for x_data + Xnormfac_forImageCondi=1., + tokenizer_X_forTextCondi=None, # for X if NEEDED only + Xnormfac_forTextCondi=1., + tokenizer_y=tokenizer_y, # None, # for output Y + ynormfac=ynormfac, + # length + train_unet_number=1, + max_length_Y=max_length_Y, # for Y, X_forImageCondi + max_text_len=max_text_len_X, # for X_forTextCondi + # other info + steps=steps, # None, + e=e, # None, + sample_dir=sample_dir, # None, + prefix=prefix, # None, + IF_showfig= CKeys['SlientRun']!=1, # True, + CKeys=CKeys, + # TBA to Model B + normalize_X_cond_to_one=False, + ) + + # sample_sequence (model, + # X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],foldproteins=foldproteins, + # flag=steps,cond_scales=1., + # ) + + # summerize loss over every epoch: + norm_loss_over_e = train_epoch_loss/len(train_loader) + print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # write this into "train_hist_file_full" + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n' + with open(train_hist_file_full,'a') as f: + f.write(add_line) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # save model every this epoches + if save_model and (e+start_ep) % 
save_every_epoch==0 and e>1: + # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt" + fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt" + trainer.save(fname) + print (f"Model saved: ", fname) + # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt" + fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt" + torch.save(model.state_dict(), fname) + print (f"Statedict model saved: ", fname) + # add a saving point file + top_line='epoch,steps,norm_loss'+'\n' + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' + with open(save_point_info_file, "w") as f: + f.write(top_line) + f.write(add_line) + + # if steps>0: + # if save_model and steps % print_loss_every_steps==0: + # fname=f"{prefix}trainer_save-model-epoch_{e}.pt" + # trainer.save(fname) + # fname=f"{prefix}statedict_save-model-epoch_{e}.pt" + # torch.save(model.state_dict(), fname) + # print (f"Model saved: ") + + # steps=steps+1 + + print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------") + +def train_loop_ForModelA_II ( + model, + train_loader, + test_loader, + # + optimizer=None, + print_every=1, + epochs= 300, + start_ep=0, + start_step=0, + train_unet_number=1, + print_loss_every_steps=1000, + # + trainer=None, + plot_unscaled=False, + max_batch_size=4, + save_model=False, + cond_scales=[1.0], #list of cond scales + num_samples=2, #how many samples produced every time tested..... + foldproteins=False, + # ++ + cond_image=False, # not use cond_images... for model A + cond_text=True, # use condi_text... for model A + # + + device=None, + loss_list=[], + epoch_list=[], + train_hist_file=None, + train_hist_file_full=None, + prefix=None, # not used in this function + Xnormfac=None, + ynormfac=1., + tokenizer_X=None, + tokenizer_y=None, + test_condition_list=[], + max_length_Y=1, + max_text_len_X=1, + CKeys=None, + sample_steps=1, + sample_dir=None, + save_every_epoch=1, + save_point_info_file=None, + store_dir=None, + # ++ for pLM + pLM_Model_Name=None, + image_channels=None, + print_error=False, # not defined for Problem6 # True, +): + # #+ + # Xnormfac=Xnormfac.to(model.device) + + if not exists (trainer): + if not exists (optimizer): + print ("ERROR: If trainer not used, need to provide optimizer.") + if exists (trainer): + print ("Trainer provided... will be used") + # -------------------------------- + # steps=start_step + # ++++++++++++++++++++++++++++++++ + steps=start_step + added_steps=0 + + loss_total=0 + + # ++ for pLM + if pLM_Model_Name=='None': + pLM_Model=None + + elif pLM_Model_Name=='esm2_t33_650M_UR50D': + # dim: 1280 + esm_layer=33 + pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() + len_toks=len(esm_alphabet.all_toks) + pLM_Model.eval() + pLM_Model. to(device) + + elif pLM_Model_Name=='esm2_t36_3B_UR50D': + # dim: 2560 + esm_layer=36 + pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() + len_toks=len(esm_alphabet.all_toks) + pLM_Model.eval() + pLM_Model. to(device) + + elif pLM_Model_Name=='esm2_t30_150M_UR50D': + # dim: 640 + esm_layer=30 + pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() + len_toks=len(esm_alphabet.all_toks) + pLM_Model.eval() + pLM_Model. to(device) + + elif pLM_Model_Name=='esm2_t12_35M_UR50D': + # dim: 480 + esm_layer=12 + pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() + len_toks=len(esm_alphabet.all_toks) + pLM_Model.eval() + pLM_Model. 
to(device) + + else: + print("pLM model is missing...") + + + for e in range(1, epochs+1): + # start = time.time() + + torch.cuda.empty_cache() + print ("######################################################################################") + start = time.time() + print ("NOW: Training epoch: ", e+start_ep) + + # TRAINING + train_epoch_loss = 0 + model.train() + + print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") + + for item in train_loader: + # ++ + steps += 1 + added_steps += 1 + + X_train_batch= item[0].to(device) + y_train_batch=item[1].to(device) + # project y_ into embedding space + if CKeys["Debug_TrainerPack"]==1: + print("Initial unload the dataloader items: ...") + print(X_train_batch.shape) + print(y_train_batch.shape) + # ++ + # project the AA seq into embedding space + # for output, it is shared between ModelA and ModelB + # # -- + # if pLM_Model_Name=='None': + # # just use the encoded sequence + # y_train_batch_in = y_train_batch.unsqueeze(1) + # # pass + # elif pLM_Model_Name=='esm2_t33_650M_UR50D': + # with torch.no_grad(): + # results = pLM_Model( + # y_train_batch, + # repr_layers=[33], + # return_contacts=False, + # ) + # y_train_batch_in = results["representations"][33] + # y_train_batch_in = rearrange( + # y_train_batch_in, + # 'b l c -> b c l' + # ) + # else: + # print(f"Required pLM name is not defined!!") + # ++ + if pLM_Model_Name=='None': + # just use the encoded sequence + y_train_batch_in = y_train_batch.unsqueeze(1) + # pass + else: # for ESM models # pLM_Model_Name=='esm2_t33_650M_UR50D': + with torch.no_grad(): + results = pLM_Model( + y_train_batch, + repr_layers=[esm_layer], + return_contacts=False, + ) + y_train_batch_in = results["representations"][esm_layer] + y_train_batch_in = rearrange( + y_train_batch_in, + 'b l c -> b c l' + ) + + + # + # For input part, this block is different for ModelA and ModelB + if cond_image==False: + # model A: X: text_condi, not affected by pLM + X_train_batch_in = X_train_batch + else: + # model B: X: cond_img, will be affected by pLM + X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) + # + # + for debug + if CKeys["Debug_TrainerPack"]==1: + print("After pLM model, the shape of X and y for training:") + print("X_train_batch_in.dim: ", X_train_batch_in.shape) + print("y_train_batch_in.dim: ", y_train_batch_in.shape) + + + + + if exists (trainer): + if cond_image==False: + # ======================================== + # Model A: condition via text + # ======================================== + # this block depends on the model:forward + loss = trainer( + # # -------------------------------- + # X_train_batch, + # y_train_batch.unsqueeze(1) , + # # ++++++++++++++++++++++++++++++++ + # y_train_batch.unsqueeze(1) , + # x=X_train_batch, + # ++ pLM + y_train_batch_in, + x=X_train_batch_in, + # + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + if cond_image==True: + # ======================================== + # Model B: condition via image/sequence + # ======================================== + # # -- + # # added for future: Train_loop B + # loss = trainer( + # y_train_batch.unsqueeze(1) , + # x=None, + # cond_images=X_train_batch.unsqueeze(1), + # unet_number=train_unet_number, + # max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + # ) + # ++ from 
pLM+ModelB + loss = trainer( + y_train_batch_in, # true image + x=None, # tokenized text + cond_images=X_train_batch_in, # cond_image + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + # pass + # + trainer.update(unet_number = train_unet_number) + + else: + optimizer.zero_grad() + if cond_image==False: + # this block depends on the model:forward + loss=model ( + # # -------------------------------- + # X_train_batch, + # y_train_batch.unsqueeze(1) , + # # ++++++++++++++++++++++++++++++++ + # y_train_batch.unsqueeze(1) , + # x=X_train_batch, + # ++ pLM + y_train_batch_in, + x=X_train_batch_in, + # + unet_number=train_unet_number + ) + if cond_image==True: + # added for future: Train_loop B + # # -- + # loss=model ( + # y_train_batch.unsqueeze(1) , + # x=None, + # cond_images=X_train_batch.unsqueeze(1), + # unet_number=train_unet_number + # ) + # ++ from pLM + loss=model ( + y_train_batch_in , + x=None, + cond_images=X_train_batch_in, + unet_number=train_unet_number + ) + # + loss.backward( ) + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + + loss_total=loss_total+loss.item() + # + + train_epoch_loss=train_epoch_loss+loss.item() + + if steps % print_every == 0: + # for progress bar + print(".", end="") + + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # record loss block + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # if steps>0: + if added_steps>0: + + if steps % print_loss_every_steps == 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print("Here is step: ", steps) + + norm_loss=loss_total/print_loss_every_steps + print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # add a line to the hist file + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' + with open(train_hist_file,'a') as f: + f.write(add_line) + + loss_list.append (norm_loss) + loss_total=0 + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + epoch_list.append(e+start_ep) + + fig = plt.figure() + plt.plot (epoch_list, loss_list, label='Loss') + plt.legend() + # outname = prefix+ f"loss_{e}_{steps}.jpg" + outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" + # + # the order, save then show, matters + if CKeys['SlientRun']==1: + plt.savefig(outname, dpi=200) + else: + plt.show() + plt.close(fig) + + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # sample in test set block + # set sample_steps < 0 to switch off this block + # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + # if steps>0: + if added_steps>0: + if steps % sample_steps == 0 and sample_steps > 0: + # + for debug + if CKeys['Debug_TrainerPack']==2: + print("Here is steps: ", steps) + + if plot_unscaled: + # test before scaling... + plt.plot ( + y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), + label= 'Unscaled GT' + ) + plt.legend() + plt.show() + +# # -- look like not used +# #rescale GT to properly plot +# GT=y_train_batch.cpu().detach() + +# GT=resize_image_to( +# GT.unsqueeze(1), +# model.imagen.image_sizes[train_unet_number-1], + +# ) + #### + print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ") + print ("I. 
SAMPLING IN TEST SET: ") + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + #### + + num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + sample_loop_omegafold_pLM_ModelA ( + model, + test_loader, + cond_scales=cond_scales, + num_samples=num_samples, #how many samples produced every time tested..... + timesteps=None, + flag=e+start_ep, # steps, + foldproteins=foldproteins, + # add condi_key + cond_image=cond_image, # Not used for now + cond_text=cond_text, # Not used for now + skip_steps=0, + # + max_text_len=max_text_len_X, + max_length=max_length_Y, + # ++++++++++++++++++++ + train_unet_number=train_unet_number, + ynormfac=ynormfac, + prefix=prefix, # + tokenizer_y=tokenizer_y, + Xnormfac_CondiText=Xnormfac, + tokenizer_X_CondiText=tokenizer_X, + # ++ + CKeys=CKeys, + sample_dir=sample_dir, + steps=steps, + e=e+start_ep, + IF_showfig= CKeys['SlientRun']!=1 , + # ++ for pLM + pLM_Model=pLM_Model, + pLM_Model_Name=pLM_Model_Name, + image_channels=image_channels, + pLM_alphabet=esm_alphabet, + ) + + print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ") + print ("II. SAMPLING FOR DE NOVO:") + print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") + + DeNovoSam_pdbs, fasta_file_list=\ + sample_sequence_omegafold_pLM_ModelA ( + # # ---------------------------------------------- + # model, + # X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]], + # foldproteins=foldproteins, + # flag=steps,cond_scales=1., + # ++++++++++++++++++++++++++++++++++++++++++++++ + model, + X=test_condition_list, # [[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X + flag=e+start_ep, # steps, # 0, + cond_scales=cond_scales, # 1., + foldproteins=True, # False, + X_string=None, # from text conditioning X_string + x_data=None, # from image conditioning x_data + skip_steps=0, + inpaint_images=None, # in formation Y data + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, # for omegafolding + calc_error=False, # for check on folded results, not used for every case + # ++++++++++++++++++++++++++ + # tokenizers + tokenizer_X_forImageCondi=None, # for x_data + Xnormfac_forImageCondi=1., + tokenizer_X_forTextCondi=None, # for X if NEEDED only + Xnormfac_forTextCondi=1., + tokenizer_y=tokenizer_y, # None, # for output Y + ynormfac=ynormfac, + # length + train_unet_number=1, + max_length_Y=max_length_Y, # for Y, X_forImageCondi + max_text_len=max_text_len_X, # for X_forTextCondi + # other info + steps=steps, # None, + e=e, # None, + sample_dir=sample_dir, # None, + prefix=prefix, # None, + IF_showfig= CKeys['SlientRun']!=1, # True, + CKeys=CKeys, + # TBA to Model B + normalize_X_cond_to_one=False, + # ++ for pLM + pLM_Model=pLM_Model, + pLM_Model_Name=pLM_Model_Name, + image_channels=image_channels, + pLM_alphabet=esm_alphabet, + ) + + # sample_sequence (model, + # X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],foldproteins=foldproteins, + # flag=steps,cond_scales=1., + # ) + + # summerize loss over every epoch: + norm_loss_over_e = train_epoch_loss/len(train_loader) + print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # write this into "train_hist_file_full" + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n' + with open(train_hist_file_full,'a') as f: + f.write(add_line) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # save model every this epoches + if save_model and (e+start_ep) % 
save_every_epoch==0 and e>1: + # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt" + fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt" + trainer.save(fname) + print (f"Model saved: ", fname) + # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt" + fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt" + torch.save(model.state_dict(), fname) + print (f"Statedict model saved: ", fname) + # add a saving point file + top_line='epoch,steps,norm_loss'+'\n' + add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' + with open(save_point_info_file, "w") as f: + f.write(top_line) + f.write(add_line) + + # if steps>0: + # if save_model and steps % print_loss_every_steps==0: + # fname=f"{prefix}trainer_save-model-epoch_{e}.pt" + # trainer.save(fname) + # fname=f"{prefix}statedict_save-model-epoch_{e}.pt" + # torch.save(model.state_dict(), fname) + # print (f"Model saved: ") + + # steps=steps+1 + + print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------") + +# from original, not used any more +def train_loop_Model_A ( + model, + train_loader, + test_loader, + optimizer=None, + print_every=10, + epochs= 300, + start_ep=0, + start_step=0, + train_unet_number=1, + print_loss=1000, + trainer=None, + plot_unscaled=False, + max_batch_size=4, + save_model=False, + cond_scales=[1.0], #list of cond scales + num_samples=2, #how many samples produced every time tested..... + foldproteins=False, +): + + + if not exists (trainer): + if not exists (optimizer): + print ("ERROR: If trainer not used, need to provide optimizer.") + if exists (trainer): + print ("Trainer provided... will be used") + steps=start_step + + loss_total=0 + for e in range(1, epochs+1): + start = time.time() + + torch.cuda.empty_cache() + print ("######################################################################################") + start = time.time() + print ("NOW: Training epoch: ", e+start_ep) + + # TRAINING + train_epoch_loss = 0 + model.train() + + print ("Loop over ", len(train_loader), " batches (print . 
every ", print_every, " steps)") + for item in train_loader: + X_train_batch= item[0].to(device) + y_train_batch=item[1].to(device) + + if exists (trainer): + loss = trainer( + X_train_batch, y_train_batch.unsqueeze(1) , + unet_number=train_unet_number, + max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory + ) + trainer.update(unet_number = train_unet_number) + + else: + optimizer.zero_grad() + loss=model ( X_train_batch, y_train_batch.unsqueeze(1) ,unet_number=train_unet_number) + loss.backward( ) + + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + + optimizer.step() + + loss_total=loss_total+loss.item() + + if steps % print_every == 0: + print(".", end="") + + if steps>0: + if steps % print_loss == 0: + + if plot_unscaled: + + plt.plot (y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),label= 'Unscaled GT') + plt.legend() + plt.show() + + #rescale GT to properly plot + GT=y_train_batch.cpu().detach() + + GT=resize_image_to( + GT.unsqueeze(1), + model.imagen.image_sizes[train_unet_number-1], + + ) + + norm_loss=loss_total/print_loss + print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}") + + loss_list.append (norm_loss) + loss_total=0 + + plt.plot (loss_list, label='Loss') + plt.legend() + + outname = prefix+ f"loss_{e}_{steps}.jpg" + plt.savefig(outname, dpi=200) + plt.show() + + num_samples = min (num_samples,y_train_batch.shape[0] ) + print (f"Producing {num_samples} samples...") + + sample_loop (model, + test_loader, + cond_scales=cond_scales, + num_samples=1, #how many samples produced every time tested..... + timesteps=64, + flag=steps,foldproteins=foldproteins, + ) + + print ("SAMPLING FOR DE NOVO:") + sample_sequence (model, + X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]],foldproteins=foldproteins, + flag=steps,cond_scales=1., + ) + sample_sequence (model, + X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],foldproteins=foldproteins, + flag=steps,cond_scales=1., + ) + + if steps>0: + if save_model and steps % print_loss==0: + fname=f"{prefix}trainer_save-model-epoch_{e}.pt" + trainer.save(fname) + fname=f"{prefix}statedict_save-model-epoch_{e}.pt" + torch.save(model.state_dict(), fname) + print (f"Model saved: ") + + steps=steps+1 + + print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------") + +# +++ +def sample_sequence_omegafold_ModelA ( + model, + X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X + flag=0, + cond_scales=1., + foldproteins=False, + X_string=None, # from text conditioning X_string + x_data=None, # from image conditioning x_data + skip_steps=0, + inpaint_images=None, # in formation Y data + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, # for omegafolding + calc_error=False, # for check on folded results, not used for every case + # ++++++++++++++++++++++++++ + # tokenizers + tokenizer_X_forImageCondi=None, # for x_data + Xnormfac_forImageCondi=1., + tokenizer_X_forTextCondi=None, # for X if NEEDED only + Xnormfac_forTextCondi=1., + tokenizer_y=None, # for output Y + ynormfac=1, + # length + train_unet_number=1, + max_length_Y=1, # for Y, X_forImageCondi + max_text_len=1, # for X_forTextCondi + # other info + steps=None, + e=None, + sample_dir=None, + prefix=None, + IF_showfig=True, + CKeys=None, + # TBA to Model B + normalize_X_cond_to_one=False, +): + # ----------- + # steps=0 + # e=flag + + # -- + # print (f"Producing {len(X)} samples...") + # ++ 
+ if X!=None: + print (f"Producing {len(X)} samples...from text conditioning X...") + lenn_val=len(X) + if X_string!=None: + lenn_val=len(X_string) + print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") + if x_data!=None: + print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") + lenn_val=len(x_data) + # print (x_data) + + print ('Device: ', model.device) + + + for iisample in range (lenn_val): + print(f"Working on {iisample}") + X_cond=None + + if X_string==None and X!=None: # for X channel + X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) + if X_string!=None: # from raw text, ie., X_string: need tokenizer_X and Xnormfac + # - + # X = tokenizer_X.texts_to_sequences(X_string[iisample]) + # X= sequence.pad_sequences(X, maxlen=max_length, padding='post', truncating='post') + # X=np.array(X) + # X_cond=torch.from_numpy(X).float()/Xnormfac + # + + XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample]) + XX = sequence.pad_sequences(XX, maxlen=max_text_len, padding='post', truncating='post') + XX = np.array(XX) + X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi + + print ('Tokenized and processed: ', X_cond) + + if X_cond!=None: + if normalize_X_cond_to_one: # used when there is constrain on X_cond.sum() + X_cond=X_cond/X_cond.sum() + + print ("Text conditoning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales) + else: + print ("Text conditioning used: None") + + # for now, assume image_condi and text_condi can be used at the same time + if tokenizer_X_forImageCondi==None: + # =========================================================== + # condi_image/seq needs no tokenization, like numbers: force_path + # only normalization needed + # Based on ModelB:Force_Path + if x_data!=None: + x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi) + x_data_tokenized=x_data_tokenized.to(torch.float) + # + for debug: + if CKeys['Debug_TrainerPack']==1: + print("x_data_tokenized dim: ", x_data_tokenized.shape) + print("x_data_tokenized dtype: ", x_data_tokenized.dtype) + print("test: ", x_data_tokenized!=None) + else: + x_data_tokenized=None + # + for debug: + if CKeys['Debug_TrainerPack']==1: + print("x_data_tokenized and x_data: None") + + # model.sample:full arguments + # self, + # x=None, + # stop_at_unet_number=1, + # cond_scale=7.5, + # # ++ + # x_data=None, # image_condi data + # skip_steps=None, + # inpaint_images = None, + # inpaint_masks = None, + # inpaint_resample_times = 5, + # init_images = None, + # x_data_tokenized=None, + # tokenizer_X=None, + # Xnormfac=1., + # # -+ + # device=None, + # max_length=1., # for XandY data, in image/sequence format; NOT for text condition + # max_text_len=1., # for X data, in text format + # + result=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + x_data=None, + # ++ + x_data_tokenized=x_data_tokenized, + # + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images, + device=model.device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X_forImageCondi, # tokenizer_X, + Xnormfac=Xnormfac_forImageCondi, # Xnormfac, + # ynormfac=ynormfac, + max_length=max_length_Y, # for ImageCondi, max_length, + max_text_len=max_text_len, + ) + else: + # + result=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + 
x_data=x_data[iisample],
+                # ++
+                x_data_tokenized=None,
+                #
+                skip_steps=skip_steps,
+                inpaint_images = inpaint_images,
+                inpaint_masks = inpaint_masks,
+                inpaint_resample_times = inpaint_resample_times,
+                init_images = init_images,
+                device=model.device,
+                # ++++++++++++++++++++++++++
+                tokenizer_X=tokenizer_X_forImageCondi, # tokenizer_X,
+                Xnormfac=Xnormfac_forImageCondi, # Xnormfac,
+                # ynormfac=ynormfac,
+                max_length=max_length_Y, # max_length,
+                max_text_len=max_text_len,
+            )
+
+        # # ------------------------------------------
+        # result=model.sample (
+        #     X_cond,
+        #     stop_at_unet_number=train_unet_number,
+        #     cond_scale=cond_scales
+        # )
+
+        result=torch.round(result*ynormfac)
+        # + for debug
+        print("result.dim: ", result.shape)
+
+        fig=plt.figure()
+        plt.plot (
+            result[0,0,:].cpu().detach().numpy(),
+            label= f'Predicted'
+        )
+        #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
+        plt.legend()
+        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
+        #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
+        if IF_showfig==1:
+            plt.show ()
+        else:
+            plt.savefig(outname, dpi=200)
+        plt.close()
+
+        # # ----------------------------------------
+        # plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted')
+        # plt.legend()
+        # outname = prefix+ f"sampled_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
+        # plt.savefig(outname, dpi=200)
+        # plt.show ()
+
+        to_rev=result[:,0,:]
+        to_rev=to_rev.long().cpu().detach().numpy()
+        print("to_rev.dim: ", to_rev.shape)
+        y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+        for iii in range (len(y_data_reversed)):
+            y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
+
+        # + from Model B
+        ### reverse secondary structure input....
+        if iisample == 0:
+            pdb_list=[]   # initialize once, before the first sample, so earlier entries are kept
+        if X_cond != None:
+            # there is condi_text
+            if X_string!=None:
+                X_cond=torch.round(X_cond*Xnormfac_forTextCondi)
+
+                to_rev=X_cond[:,:]
+                to_rev=to_rev.long().cpu().detach().numpy()
+                print ("to_rev.dim: ", to_rev.shape)
+                # --
+                # X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
+                # ++
+                X_text_reversed=tokenizer_X_forTextCondi.sequences_to_texts (to_rev)
+                for iii in range (len(X_text_reversed)):
+                    X_text_reversed[iii]=X_text_reversed[iii].upper().strip().replace(" ", "")
+
+            if X_string==None:
+                # reverse this: X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
+                X_text_reversed=X_cond
+        else:
+            X_text_reversed=None
+
+        if x_data !=None: # there is condi_image
+            x_data_reversed=x_data # is already in sequence format..
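+            # The reversals above only rebuild human-readable copies of the
+            # conditioning inputs for the summary print below. As a hedged
+            # illustration of the Keras-style decode used here: with
+            # index_word = {1: 'a', 2: 'l'}, sequences_to_texts([[1, 2, 2]])
+            # gives ['a l l'], and .upper().strip().replace(" ", "") turns it
+            # into 'ALL'; index 0 is padding and decodes to nothing.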
+ else: + x_data_reversed=None + + # summary + # print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,\n predicted sequence: ", y_data_reversed) + print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,") + print (f"predicted sequence full: {y_data_reversed}") + # add just for incase check + print (f"predicted sequence: {y_data_reversed[0]}") + + # + for debug + print("================================================") + print("foldproteins: ", foldproteins) + + if not foldproteins: + pdb_file=None + else: + # if foldproteins: + + if X_cond != None: + pass + + tempname='temp' + pdb_file=foldandsavePDB ( + sequence=y_data_reversed[0], + filename_out=tempname, + num_cycle=num_cycle, + flag=flag, + # +++++++++++++++++++ + # prefix=prefix, + prefix=sample_dir, + ) + # + out_nam=iisample + # + out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' + write_fasta (y_data_reversed[0], out_nam_fasta) + # + out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb' + shutil.copy (pdb_file, out_nam) #source, dest + pdb_file=out_nam + # + print (f"Properly named PDB file produced: {pdb_file}") + if IF_showfig==1: + #flag=1000 + view=show_pdb( + pdb_file=pdb_file, + flag=flag, + show_sidechains=show_sidechains, + show_mainchains=show_mainchains, + color=color + ) + view.show() + + if calc_error: + # only work for ModelA:SecStr + if CKeys['Problem_ID']==7: + get_Model_A_error (pdb_file, X[iisample], plotit=True) + else: + print ("Error calculation on the predicted results is not applicable") + + pdb_list.append(pdb_file) + + return pdb_list + + +# xbc=X_cond[iisample,:].cpu().detach().numpy() +# out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.2f" % xbc})+f'_{flag}_{steps}' +# tempname='temp' +# pdb_file=foldandsavePDB (sequence=y_data_reversed[0], +# filename_out=tempname, +# num_cycle=16, flag=flag) +# out_nam_fasta=f'{prefix}{out_nam}.fasta' + +# out_nam=f'{prefix}{out_nam}.pdb' + +# write_fasta (y_data_reversed[0], out_nam_fasta) + +# shutil.copy (pdb_file, out_nam) #source, dest +# pdb_file=out_nam +# print (f"Properly named PDB file produced: {pdb_file}") + +# view=show_pdb(pdb_file=pdb_file, flag=flag, +# show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color) +# view.show() + +# if calc_error: +# get_Model_A_error (pdb_file, X[iisample], plotit=True) +# return pdb_file + +# +++ +def sample_sequence_omegafold_pLM_ModelA ( + model, + X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X + flag=0, + cond_scales=1., + foldproteins=False, + X_string=None, # from text conditioning X_string + x_data=None, # from image conditioning x_data + skip_steps=0, + inpaint_images=None, # in formation Y data + inpaint_masks = None, + inpaint_resample_times = None, + init_images = None, + num_cycle=16, # for omegafolding + calc_error=False, # for check on folded results, not used for every case + # ++++++++++++++++++++++++++ + # tokenizers + tokenizer_X_forImageCondi=None, # for x_data + Xnormfac_forImageCondi=1., + tokenizer_X_forTextCondi=None, # for X if NEEDED only + Xnormfac_forTextCondi=1., + tokenizer_y=None, # for output Y + ynormfac=1, + # length + train_unet_number=1, + max_length_Y=1, # for Y, X_forImageCondi + max_text_len=1, # for X_forTextCondi + # other info + steps=None, + e=None, + sample_dir=None, + prefix=None, + IF_showfig=True, + CKeys=None, + # TBA to Model B + normalize_X_cond_to_one=False, + # 
++ + pLM_Model=None, # pLM_Model, + pLM_Model_Name=None, # pLM_Model_Name, + image_channels=None, # image_channels, + pLM_alphabet=None, # esm_alphabet, +): + # ----------- + # steps=0 + # e=flag + + # -- + # print (f"Producing {len(X)} samples...") + # ++ + if X!=None: + print (f"Producing {len(X)} samples...from text conditioning X...") + lenn_val=len(X) + if X_string!=None: + lenn_val=len(X_string) + print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") + if x_data!=None: + print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") + lenn_val=len(x_data) + # print (x_data) + + print ('Device: ', model.device) + + pdb_list=[] + fasta_list=[] + + for iisample in range (lenn_val): + print(f"Working on {iisample}") + X_cond=None + + if X_string==None and X!=None: # for X channel + X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) + if X_string!=None: # from raw text, ie., X_string: need tokenizer_X and Xnormfac + # - + # X = tokenizer_X.texts_to_sequences(X_string[iisample]) + # X= sequence.pad_sequences(X, maxlen=max_length, padding='post', truncating='post') + # X=np.array(X) + # X_cond=torch.from_numpy(X).float()/Xnormfac + # + + XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample]) + XX = sequence.pad_sequences(XX, maxlen=max_text_len, padding='post', truncating='post') + XX = np.array(XX) + X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi + + print ('Tokenized and processed: ', X_cond) + + if X_cond!=None: + if normalize_X_cond_to_one: # used when there is constrain on X_cond.sum() + X_cond=X_cond/X_cond.sum() + + print ("Text conditoning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales) + else: + print ("Text conditioning used: None") + + # for now, assume image_condi and text_condi can be used at the same time + if tokenizer_X_forImageCondi==None: + # =========================================================== + # condi_image/seq needs no tokenization, like numbers: force_path + # only normalization needed + # Based on ModelB:Force_Path + if x_data!=None: + x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi) + x_data_tokenized=x_data_tokenized.to(torch.float) + # + for debug: + if CKeys['Debug_TrainerPack']==1: + print("x_data_tokenized dim: ", x_data_tokenized.shape) + print("x_data_tokenized dtype: ", x_data_tokenized.dtype) + print("test: ", x_data_tokenized!=None) + else: + x_data_tokenized=None + # + for debug: + if CKeys['Debug_TrainerPack']==1: + print("x_data_tokenized and x_data: None") + + # model.sample:full arguments + # self, + # x=None, + # stop_at_unet_number=1, + # cond_scale=7.5, + # # ++ + # x_data=None, # image_condi data + # skip_steps=None, + # inpaint_images = None, + # inpaint_masks = None, + # inpaint_resample_times = 5, + # init_images = None, + # x_data_tokenized=None, + # tokenizer_X=None, + # Xnormfac=1., + # # -+ + # device=None, + # max_length=1., # for XandY data, in image/sequence format; NOT for text condition + # max_text_len=1., # for X data, in text format + # + result_embedding=model.sample ( + x=X_cond, + stop_at_unet_number=train_unet_number , + cond_scale=cond_scales , + x_data=None, + # ++ + x_data_tokenized=x_data_tokenized, + # + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images, + device=model.device, + # ++++++++++++++++++++++++++ + tokenizer_X=tokenizer_X_forImageCondi, # 
+                Xnormfac=Xnormfac_forImageCondi, # Xnormfac,
+                # ynormfac=ynormfac,
+                max_length=max_length_Y, # for ImageCondi, max_length,
+                max_text_len=max_text_len,
+            )
+        else:
+            # this is for model B in the future
+            # two channels should be provided: raw cond_img+img_tokenizer or tokenized_cond_img
+            # needs to BE UPDATED and merged with code from
+            # fun.sample_sequence_omegafold_pLM_ModelB
+            # one branch is currently missing
+            result_embedding=model.sample (
+                x=X_cond,
+                stop_at_unet_number=train_unet_number ,
+                cond_scale=cond_scales ,
+                x_data=x_data[iisample],
+                # ++
+                x_data_tokenized=None,
+                #
+                skip_steps=skip_steps,
+                inpaint_images = inpaint_images,
+                inpaint_masks = inpaint_masks,
+                inpaint_resample_times = inpaint_resample_times,
+                init_images = init_images,
+                device=model.device,
+                # ++++++++++++++++++++++++++
+                tokenizer_X=tokenizer_X_forImageCondi, # tokenizer_X,
+                Xnormfac=Xnormfac_forImageCondi, # Xnormfac,
+                # ynormfac=ynormfac,
+                max_length=max_length_Y, # max_length,
+                max_text_len=max_text_len,
+            )
+
+        # # ------------------------------------------
+        # result=model.sample (
+        #     X_cond,
+        #     stop_at_unet_number=train_unet_number,
+        #     cond_scale=cond_scales
+        # )
+
+        # # -----------------------------------------------
+        # result=torch.round(result*ynormfac)
+        # +++++++++++++++++++++++++++++++++++++++++++++++
+        # ++ for pLM
+        # full record
+        # result_embedding as image, dim: [batch, channels, seq_len]
+        # result_tokens, dim: [batch, seq_len]
+        result_tokens,result_logits = convert_into_tokens(
+            pLM_Model,
+            result_embedding,
+            pLM_Model_Name,
+        )
+        result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]
+
+        # + for debug
+        print("result.dim: ", result.shape)
+
+        fig=plt.figure()
+        plt.plot (
+            result[0,0,:].cpu().detach().numpy(),
+            label= f'Predicted'
+        )
+        # plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
+        plt.legend()
+        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
+        # plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
+        if IF_showfig==1:
+            plt.show ()
+        else:
+            plt.savefig(outname, dpi=200)
+        plt.close()
+
+        # # ----------------------------------------
+        # plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted')
+        # plt.legend()
+        # outname = prefix+ f"sampld_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
+        # plt.savefig(outname, dpi=200)
+        # plt.show ()
+
+        # # ---------------------------------------------------------
+        # to_rev=result[:,0,:]
+        # to_rev=to_rev.long().cpu().detach().numpy()
+        # print("to_rev.dim: ", to_rev.shape)
+        # y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+        # for iii in range (len(y_data_reversed)):
+        #     y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
+        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+        to_rev=result[:,0,:]
+        # the following function decides the ending automatically
+        y_data_reversed=decode_many_ems_token_rec_for_folding(
+            to_rev,
+            result_logits,
+            pLM_alphabet,
+            pLM_Model,
+        )
+        if CKeys['Debug_TrainerPack']==3:
+            print("on y_data_reversed[0]: ", y_data_reversed[0])
+
+        # + from Model B
+        ### reverse secondary structure input....
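+        # + an optional, hedged sanity check (added here, not part of the original
+        #   pipeline): per the comment above, convert_into_tokens is expected to return
+        #   result_tokens with dim [batch, seq_len]; the check only runs under the
+        #   existing debug flag and has no effect otherwise.
+        if CKeys['Debug_TrainerPack']==1:
+            assert result_tokens.dim()==2, "expected result_tokens as [batch, seq_len]"
+            print("decoded sequences #: ", len(y_data_reversed))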
+
+        if X_cond != None:
+            # there is condi_text
+            if X_string!=None:
+                X_cond=torch.round(X_cond*Xnormfac_forTextCondi)
+
+                to_rev=X_cond[:,:]
+                to_rev=to_rev.long().cpu().detach().numpy()
+                print ("to_rev.dim: ", to_rev.shape)
+                # --
+                # X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
+                # ++
+                X_text_reversed=tokenizer_X_forTextCondi.sequences_to_texts (to_rev)
+                for iii in range (len(X_text_reversed)):
+                    X_text_reversed[iii]=X_text_reversed[iii].upper().strip().replace(" ", "")
+
+            if X_string==None:
+                # reverse this: X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
+                X_text_reversed=X_cond
+        else:
+            X_text_reversed=None
+
+        if x_data !=None: # there is condi_image
+            x_data_reversed=x_data # is already in sequence format..
+        else:
+            x_data_reversed=None
+
+        # summary
+        print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,\n predicted sequence: ", y_data_reversed)
+
+        # + for debug
+        print("================================================")
+        print("foldproteins: ", foldproteins)
+
+        if not foldproteins:
+            pdb_file=None
+        else:
+            # if foldproteins:
+
+            if X_cond != None:
+                pass
+
+            tempname='temp'
+            pdb_file, fasta_file=foldandsavePDB_pdb_fasta (
+                sequence=y_data_reversed[0],
+                filename_out=tempname,
+                num_cycle=num_cycle,
+                flag=flag,
+                # +++++++++++++++++++
+                # prefix=prefix,
+                prefix=sample_dir,
+            )
+            #
+            out_nam=iisample
+            #
+            # out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
+            # write_fasta (y_data_reversed[0], out_nam_fasta)
+            #
+            out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
+            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
+            shutil.copy (pdb_file, out_nam) # source, dest
+            shutil.copy (fasta_file, out_nam_fasta)
+            # clean the slate to avoid mistakenly reusing the previous fasta file
+            os.remove (pdb_file)
+            os.remove (fasta_file)
+            #
+            pdb_file=out_nam
+            fasta_file=out_nam_fasta
+            #
+            pdb_list.append(pdb_file)
+            fasta_list.append(fasta_file)
+            #
+            print (f"Properly named PDB file produced: {pdb_file}")
+            if IF_showfig==1:
+                # flag=1000
+                view=show_pdb(
+                    pdb_file=pdb_file,
+                    flag=flag,
+                    show_sidechains=show_sidechains,
+                    show_mainchains=show_mainchains,
+                    color=color
+                )
+                view.show()
+
+            if calc_error:
+                if CKeys['Problem_ID']==7:
+                    # only works for ModelA:SecStr
+                    get_Model_A_error (pdb_file, X[iisample], plotit=True)
+                else:
+                    print("Error calculation on the predicted results is not applicable...")
+
+    return pdb_list, fasta_list
+
+
+# + TBU
+# ++++++++++++++++++++++++++++++++++++++++++++++++
+def sample_loop_omegafold_ModelA (
+    model,
+    train_loader,
+    cond_scales=None, # [7.5], # list of cond scales - each sampled...
+    num_samples=None, # 2, # how many samples produced every time tested.....
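+    # + worked example (added for clarity): with num_samples=2, cond_scales=[7.5, 1.0]
+    #   and 10 mini-batches in train_loader, up to 2*2*10 = 40 samples are produced.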
+    timesteps=None, # 100, # not used
+    flag=None, # 0,
+    foldproteins=False,
+    #
+    cond_image=False, # use_text_embedd=True,
+    cond_text=True,
+    skip_steps=0,
+    #
+    max_text_len=None,
+    max_length=None,
+    # +++++++++++++++++++
+    train_unet_number=1,
+    ynormfac=None,
+    prefix=None,
+    tokenizer_y=None,
+    Xnormfac_CondiText=1,
+    tokenizer_X_CondiText=None,
+    # ++
+    CKeys=None,
+    sample_dir=None,
+    steps=None,
+    e=None,
+    IF_showfig=True, # effective only when foldproteins=True
+):
+    # =====================================================
+    # total sample # = num_samples * len(cond_scales) * (# of mini-batches)
+    # =====================================================
+    # steps=0
+    # e=flag
+    # for item in train_loader:
+    for idx, item in enumerate(train_loader):
+
+        X_train_batch= item[0].to(device)
+        y_train_batch=item[1].to(device)
+
+        GT=y_train_batch.cpu().detach()
+
+        GT= GT.unsqueeze(1)
+        # ++ scale GT back once per batch; doing this inside the cond_scales loop
+        #    (as before) would re-multiply GT by ynormfac on every pass
+        GT=torch.round (GT*ynormfac)
+
+        if num_samples>y_train_batch.shape[0]:
+            print("Warning: sampling # > len(mini_batch)")
+
+        num_samples = min (num_samples,y_train_batch.shape[0] )
+        print (f"Producing {num_samples} samples...")
+        X_train_batch_picked = X_train_batch[:num_samples,:]
+        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)
+
+        # loop over cond_scales:list
+        for iisample in range (len (cond_scales)):
+
+            # ++ for model A
+            result=model.sample (
+                x=X_train_batch_picked,
+                stop_at_unet_number=train_unet_number,
+                cond_scale=cond_scales[iisample],
+                #
+                skip_steps=skip_steps,
+                device=model.device,
+                #
+                max_length=max_length,
+                max_text_len=max_text_len,
+            )
+            # # ++ for model B
+            # if use_text_embedd:
+            #     result=model.sample (
+            #         # x= X_train_batch,
+            #         x= X_train_batch_picked,
+            #         stop_at_unet_number=train_unet_number ,
+            #         cond_scale=cond_scales[iisample],
+            #         device=device,
+            #         skip_steps=skip_steps
+            #     )
+            # else:
+            #     result=model.sample (
+            #         x= None,
+            #         # x_data_tokenized= X_train_batch,
+            #         x_data_tokenized= X_train_batch_picked,
+            #         stop_at_unet_number=train_unet_number ,
+            #         cond_scale=cond_scales[iisample],
+            #         device=device,
+            #         skip_steps=skip_steps
+            #     )
+
+            result=torch.round(result*ynormfac)
+
+            for samples in range (num_samples):
+                print ("sample ", samples+1, "out of ", num_samples)
+
+                fig=plt.figure()
+                plt.plot (
+                    result[samples,0,:].cpu().detach().numpy(),
+                    label= f'Predicted'
+                )
+                plt.plot (
+                    GT[samples,0,:],
+                    label= f'GT {0}'
+                )
+                plt.legend()
+                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
+                if IF_showfig==1:
+                    plt.show()
+                else:
+                    plt.savefig(outname, dpi=200)
+                plt.close ()
+
+                #reverse y sequence
+                to_rev=result[:,0,:]
+                to_rev=to_rev.long().cpu().detach().numpy()
+
+                y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+                for iii in range (len(y_data_reversed)):
+                    y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
+
+                #reverse GT_y sequence
+                to_rev=GT[:,0,:]
+                to_rev=to_rev.long().cpu().detach().numpy()
+
+                GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+                for iii in range (len(y_data_reversed)):
+                    GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")
+
+                ### reverse secondary structure input....
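+                # + a hedged note (added for clarity, not from the original code):
+                #   torch.FloatTensor(...) below treats a bare Python int as a tensor
+                #   *size*, so per-feature normalization factors should be passed as a
+                #   list/array (e.g. [2.0] rather than 2); the print only runs under the
+                #   existing debug flag.
+                if CKeys['Debug_TrainerPack']==1:
+                    print("X_train_batch.dim: ", X_train_batch.shape)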
+                # pay attention to the shape of Xnormfac
+                # -
+                # to_rev=torch.round (X_train_batch[:,:]*Xnormfac_CondiText)
+                # +
+                to_rev=torch.round (X_train_batch[:,:]*torch.FloatTensor(Xnormfac_CondiText).to(model.device))
+                to_rev=to_rev.long().cpu().detach().numpy()
+
+                # ++ different input
+                if CKeys['Debug_TrainerPack']==1:
+                    print("tokenizer_X_CondiText: ", tokenizer_X_CondiText)
+                    print("Xnormfac_CondiText: ", Xnormfac_CondiText)
+
+                if tokenizer_X_CondiText!=None:
+                    X_data_reversed=tokenizer_X_CondiText.sequences_to_texts (to_rev)
+                    for iii in range (len(y_data_reversed)):
+                        X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
+                else:
+                    X_data_reversed=to_rev.copy()
+                # + for debug
+                if CKeys['Debug_TrainerPack']==1:
+                    print("X_data_reversed: ", X_data_reversed)
+
+                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples])
+                print (f"Ground truth: {GT_y_data_reversed[samples]}")
+
+                if foldproteins:
+                    xbc=X_train_batch[samples,:].cpu().detach().numpy()
+                    out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
+                    tempname='temp'
+                    pdb_file=foldandsavePDB (
+                        sequence=y_data_reversed[samples],
+                        filename_out=tempname,
+                        num_cycle=16, flag=flag,
+                        # +++++++++++++++++++
+                        prefix=prefix
+                    )
+
+                    # # out_nam=f'{prefix}{out_nam}.pdb'
+                    # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
+                    # ------------------------------------------------------
+                    # sometimes the name below can get too long to fit
+                    # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
+                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
+                    # add a way to save the sampling name and results
+                    # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
+                    out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
+                    out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt'
+
+                    if CKeys['Debug_TrainerPack']==1:
+                        print("pdb_file: ", pdb_file)
+                        print("out_nam: ", out_nam)
+
+                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
+                    shutil.copy (pdb_file, out_nam) # source, dest
+                    #
+                    with open(out_nam_inX, "w") as inX_file:
+                        inX_file.write(f'{X_data_reversed[samples]}\n')
+
+                    pdb_file=out_nam
+                    print (f"Properly named PDB file produced: {pdb_file}")
+                    print (f"input X for sampling stored: {out_nam_inX}")
+
+                    if IF_showfig==1:
+                        view=show_pdb(
+                            pdb_file=pdb_file,
+                            flag=flag,
+                            show_sidechains=show_sidechains,
+                            show_mainchains=show_mainchains,
+                            color=color
+                        )
+                        view.show()
+
+# steps=steps+1
+
+# if steps>num_samples:
+#     break
+
+# + TBU
+# ++++++++++++++++++++++++++++++++++++++++++++++++
+def sample_loop_omegafold_pLM_ModelA (
+    model,
+    train_loader,
+    cond_scales=None, # [7.5], # list of cond scales - each sampled...
+    num_samples=None, # 2, # how many samples produced every time tested.....
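+    # + note (added for clarity): unlike sample_loop_omegafold_ModelA above, this pLM
+    #   variant samples embeddings, which are mapped back to tokens via
+    #   convert_into_tokens before decoding and folding below.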
+    timesteps=None, # 100, # not used
+    flag=None, # 0,
+    foldproteins=False,
+    #
+    cond_image=False, # use_text_embedd=True,
+    cond_text=True,
+    skip_steps=0,
+    #
+    max_text_len=None,
+    max_length=None,
+    # +++++++++++++++++++
+    train_unet_number=1,
+    ynormfac=None,
+    prefix=None,
+    tokenizer_y=None,
+    Xnormfac_CondiText=1,
+    tokenizer_X_CondiText=None,
+    # ++
+    CKeys=None,
+    sample_dir=None,
+    steps=None,
+    e=None,
+    IF_showfig=True, # effective only when foldproteins=True
+    # ++ for pLM
+    pLM_Model=None,
+    pLM_Model_Name=None,
+    image_channels=None,
+    pLM_alphabet=None,
+    # ++ for on-the-fly check: for SecStr only
+    calc_error=False, # for checks on folded results; not used in every case
+):
+    # =====================================================
+    # total sample # = num_samples * len(cond_scales) * (# of mini-batches)
+    # =====================================================
+    # steps=0
+    # e=flag
+    # for item in train_loader:
+    for idx, item in enumerate(train_loader):
+
+        X_train_batch= item[0].to(device)
+        y_train_batch=item[1].to(device)
+
+        GT=y_train_batch.cpu().detach()
+
+        GT= GT.unsqueeze(1)
+        if num_samples>y_train_batch.shape[0]:
+            print("Warning: sampling # > len(mini_batch)")
+
+        num_samples = min (num_samples,y_train_batch.shape[0] )
+        print (f"Producing {num_samples} samples...")
+        X_train_batch_picked = X_train_batch[:num_samples,:]
+        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)
+
+        # loop over cond_scales:list
+        for iisample in range (len (cond_scales)):
+
+            # ++ for model A
+            result_embedding = model.sample (
+                x=X_train_batch_picked,
+                stop_at_unet_number=train_unet_number,
+                cond_scale=cond_scales[iisample],
+                #
+                skip_steps=skip_steps,
+                device=model.device,
+                #
+                max_length=max_length,
+                max_text_len=max_text_len,
+                #
+                x_data=None,
+                x_data_tokenized=None,
+                #
+                tokenizer_X=tokenizer_X_CondiText,
+                Xnormfac=Xnormfac_CondiText,
+            )
+            # # ++ for model B
+            # if use_text_embedd:
+            #     result=model.sample (
+            #         # x= X_train_batch,
+            #         x= X_train_batch_picked,
+            #         stop_at_unet_number=train_unet_number ,
+            #         cond_scale=cond_scales[iisample],
+            #         device=device,
+            #         skip_steps=skip_steps
+            #     )
+            # else:
+            #     result=model.sample (
+            #         x= None,
+            #         # x_data_tokenized= X_train_batch,
+            #         x_data_tokenized= X_train_batch_picked,
+            #         stop_at_unet_number=train_unet_number ,
+            #         cond_scale=cond_scales[iisample],
+            #         device=device,
+            #         skip_steps=skip_steps
+            #     )
+
+            # ++ for pLM:
+            # full record
+            # result_embedding as image, dim: [batch, channels, seq_len]
+            # result_tokens, dim: [batch, seq_len]
+            result_tokens,result_logits = convert_into_tokens(
+                pLM_Model,
+                result_embedding,
+                pLM_Model_Name,
+            )
+
+            # # --------------------------------------------
+            # result=torch.round(result*ynormfac)
+            # GT=torch.round (GT*ynormfac)
+            # ++++++++++++++++++++++++++++++++++++++++++++
+            result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]
+
+            # # ---------------------------------------------------------
+            # #reverse y sequence
+            # to_rev=result[:,0,:]
+            # to_rev=to_rev.long().cpu().detach().numpy()
+
+            # y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+            # for iii in range (len(y_data_reversed)):
+            #     y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
+            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+            to_rev=result[:,0,:] # token (batch,seq_len)
+            y_data_reversed=decode_many_ems_token_rec_for_folding(
+                to_rev,
+                result_logits,
+                pLM_alphabet,
+                pLM_Model,
+            )
+            if CKeys['Debug_TrainerPack']==3:
+                print("on y_data_reversed[0]: ", y_data_reversed[0])
+
+            # # -----------------------------------------------------------
+            # #reverse GT_y sequence
+            # to_rev=GT[:,0,:]
+            # to_rev=to_rev.long().cpu().detach().numpy()
+
+            # GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
+
+            # for iii in range (len(y_data_reversed)):
+            #     GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")
+            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+            #reverse GT_y sequence
+            # GT should be SAFE to reverse
+            to_rev=GT[:,0,:]
+            GT_y_data_reversed=decode_many_ems_token_rec(
+                to_rev,
+                pLM_alphabet,
+            )
+
+            ### reverse secondary structure input....
+            # pay attention to the shape of Xnormfac
+            # -
+            # to_rev=torch.round (X_train_batch[:,:]*Xnormfac_CondiText)
+            # +
+            # print("X_train_batch", X_train_batch)
+            # print("Xnormfac_CondiText: ", Xnormfac_CondiText)
+
+            # to_rev=torch.round (X_train_batch[:,:]*torch.FloatTensor(Xnormfac_CondiText).to(model.device))
+            # print("X_train_batch: ", X_train_batch[:,:])
+            # print("torch.tensor(Xnormfac_CondiText): ", torch.tensor(Xnormfac_CondiText))
+            to_rev=X_train_batch[:,:]*torch.tensor(Xnormfac_CondiText).to(model.device)
+            # print("to_rev ", to_rev)
+            # # -: convert into int64
+            # to_rev=to_rev.long().cpu().detach().numpy()
+            # +: just float
+            to_rev=to_rev.cpu().detach().numpy()
+            # print("to_rev 2", to_rev)
+
+            # ++ different input
+            if CKeys['Debug_TrainerPack']==1:
+                print("tokenizer_X_CondiText: ", tokenizer_X_CondiText)
+                print("Xnormfac_CondiText: ", Xnormfac_CondiText)
+
+            if tokenizer_X_CondiText!=None:
+                # round the numbers into tokens
+                to_rev = np.round(to_rev)
+                X_data_reversed=tokenizer_X_CondiText.sequences_to_texts (to_rev)
+                for iii in range (len(y_data_reversed)):
+                    X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
+            else:
+                X_data_reversed=to_rev.copy()
+            # + for debug
+            if CKeys['Debug_TrainerPack']==1:
+                print("X_data_reversed: ", X_data_reversed)
+                print("X_data_reversed.dim: ", X_data_reversed.shape)
+
+            for samples in range (num_samples):
+                print ("sample ", samples+1, "out of ", num_samples)
+
+                fig=plt.figure()
+                plt.plot (
+                    result[samples,0,:].cpu().detach().numpy(),
+                    label= f'Predicted'
+                )
+                plt.plot (
+                    GT[samples,0,:],
+                    label= f'GT {0}'
+                )
+                plt.legend()
+                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
+                if IF_showfig==1:
+                    plt.show()
+                else:
+                    plt.savefig(outname, dpi=200)
+                plt.close ()
+
+                print (f"For input in dataloader: {X_train_batch[samples,:].cpu().detach().numpy()} or \n recovered input {X_data_reversed[samples]}")
+                print (f"predicted sequence: {y_data_reversed[samples]}")
+                print (f"Ground truth: {GT_y_data_reversed[samples]}")
+
+                if foldproteins:
+                    # check whether the predicted sequence is valid
+                    if len(y_data_reversed[samples])>0:
+                        # # --
+                        # xbc=X_train_batch[samples,:].cpu().detach().numpy()
+                        # # out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
+                        # out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
+                        # ++
+                        xbc=X_data_reversed[samples]
+                        out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})
+                        #
+                        tempname='temp'
+                        pdb_file,fasta_file=foldandsavePDB_pdb_fasta (
+                            sequence=y_data_reversed[samples],
+                            filename_out=tempname,
+                            num_cycle=16, flag=flag,
+                            # +++++++++++++++++++
+                            prefix=prefix
+                        )
+
+                        # # out_nam=f'{prefix}{out_nam}.pdb'
+                        # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
+                        # ------------------------------------------------------
+                        # sometimes the name below can get too long to fit
+                        # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
+                        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
+                        # add a way to save the sampling name and results
+                        # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
+                        out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
+                        out_nam_seq=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.fasta'
+                        out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt'
+
+                        if CKeys['Debug_TrainerPack']==1:
+                            print("pdb_file: ", pdb_file)
+                            print("out_nam: ", out_nam)
+
+                        print (f'Original PDB: {pdb_file} OUT: {out_nam}')
+                        shutil.copy (pdb_file, out_nam) # source, dest
+                        shutil.copy (fasta_file, out_nam_seq)
+                        #
+                        with open(out_nam_inX, "w") as inX_file:
+                            # inX_file.write(f'{X_data_reversed[samples]}\n')
+                            inX_file.write(out_nam_content)
+                        # clean the slate to avoid mistakenly reusing the previous fasta file
+                        os.remove (pdb_file)
+                        os.remove (fasta_file)
+
+                        pdb_file=out_nam
+                        print (f"Properly named PDB file produced: {pdb_file}")
+                        print (f"input X for sampling stored: {out_nam_inX}")
+
+                        if IF_showfig==1:
+                            view=show_pdb(
+                                pdb_file=pdb_file,
+                                flag=flag,
+                                show_sidechains=show_sidechains,
+                                show_mainchains=show_mainchains,
+                                color=color
+                            )
+                            view.show()
+
+                        if calc_error:
+                            print('On-the-fly check...')
+                            if CKeys['Problem_ID']==7:
+                                # only works for ModelA:SecStr
+                                get_Model_A_error (pdb_file, X_data_reversed[samples], plotit=True)
+                            else:
+                                print("Error calculation on the predicted results is not applicable...")
+
+                    else:
+                        print("The predicted sequence is EMPTY...")
+
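+# + a minimal, hedged usage sketch (added for illustration; `designer`, `test_loader`,
+#   `esm_alphabet`, and the directory/key values below are placeholders for objects the
+#   surrounding training script has already built, not a prescribed API):
+#
+# sample_loop_omegafold_pLM_ModelA (
+#     model=designer,
+#     train_loader=test_loader,
+#     cond_scales=[7.5],
+#     num_samples=2,
+#     flag=0,
+#     foldproteins=True,
+#     train_unet_number=1,
+#     tokenizer_X_CondiText=None, # raw-number conditioning, no tokenizer
+#     Xnormfac_CondiText=[1.],
+#     CKeys={'Debug_TrainerPack':0, 'Problem_ID':7},
+#     sample_dir='./samples/', prefix='./samples/',
+#     steps=0, e=0,
+#     IF_showfig=0,
+#     pLM_Model=pLM_Model, pLM_Model_Name=pLM_Model_Name,
+#     pLM_alphabet=esm_alphabet,
+#     calc_error=False,
+# )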