import os
import time
import copy
from pathlib import Path
from math import ceil
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from collections.abc import Iterable

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.cuda.amp import autocast, GradScaler

import pytorch_warmup as warmup

import shutil

import esm
from einops import rearrange

from packaging import version
__version__ = '1.9.3'

import matplotlib.pyplot as plt
import numpy as np

from ema_pytorch import EMA

from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem

device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu"
)

def cycle(dl):
    # loop over a dataloader endlessly (for step-based training)
    while True:
        for data in dl:
            yield data

# project-specific model, utility, and dataset helpers
from PD_pLMProbXDiff.ModelPack import (
    resize_image_to, ProteinDesigner_B,
)
from PD_pLMProbXDiff.UtilityPack import (
    get_Model_A_error, convert_into_tokens, convert_into_tokens_using_prob,
    decode_one_ems_token_rec, decode_many_ems_token_rec,
    decode_one_ems_token_rec_for_folding, 
    decode_many_ems_token_rec_for_folding,
    decode_one_ems_token_rec_for_folding_with_mask,
    decode_many_ems_token_rec_for_folding_with_mask,
    read_mask_from_input,
    get_DSSP_result, 
    string_diff,
    load_in_pLM,
)
from PD_pLMProbXDiff.DataSetPack import (
    pad_a_np_arr
)

# loss functions

criterion_MSE_sum = nn.MSELoss(reduction='sum')
criterion_MAE_sum = nn.L1Loss(reduction='sum')

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = 1):
    if isinstance(val, list):
        val = tuple(val)
    
    return val if isinstance(val, tuple) else ((val,) * length)

def find_first_index(fn, arr):
    # index of the first element satisfying fn, or -1 if none match
    # (a second, element-returning find_first is defined later in this file)
    for ind, el in enumerate(arr):
        if fn(el):
            return ind
    return -1

def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def string_begins_with(prefix, s):
    return s.startswith(prefix)

def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
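
# usage sketch (hypothetical values): this is how the trainers below route
# 'ema_'- and 'accelerate_'-prefixed kwargs to their consumers, e.g.
#   ema_kwargs, rest = groupby_prefix_and_trim('ema_', {'ema_beta': 0.99, 'lr': 1e-4})
#   # -> ema_kwargs == {'beta': 0.99}, rest == {'lr': 1e-4}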

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
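
# e.g. num_to_groups(25, 10) -> [10, 10, 5]; used below to break a requested
# sample batch into sub-batches of at most max_batch_size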

# url to fs, bucket, path - for checkpointing to cloud

def url_to_bucket(url):
    if '://' not in url:
        return url

    prefix, suffix = url.split('://')

    if prefix in {'gs', 's3'}:
        return suffix.split('/')[0]
    else:
        raise ValueError(f'storage type prefix "{prefix}" is not supported yet')

# decorators

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

def cast_torch_tensor(fn, cast_fp16 = False):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', model.device)
        cast_device = kwargs.pop('_cast_device', True)

        should_cast_fp16 = cast_fp16 and model.cast_half_at_training

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        if should_cast_fp16:
            all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner
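
# behaviour sketch for the decorator above: numpy arrays (positional or
# keyword) are converted to tensors and, unless _cast_device = False is passed,
# moved to the model's device before the wrapped method runs; when built with
# cast_fp16 = True and the model trains in fp16, non-bool tensors are
# additionally cast to half precision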

# gradient accumulation functions

def split_iterable(it, split_size):
    accum = []
    for ind in range(ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum

def split(t, split_size = None):
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    raise TypeError(f'cannot split object of type {type(t)}')

def find_first(cond, arr):
    # first element satisfying cond, or None if no element matches
    for el in arr:
        if cond(el):
            return el
    return None

def split_args_and_kwargs(*args, split_size = None, **kwargs):
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = tuple(map(len, split_all_args[0]))

    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
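
# usage sketch: a batch of 10 with split_size = 4 yields chunks of 4, 4 and 2;
# each chunk comes with its fraction of the full batch, used below to
# re-weight per-chunk losses during gradient accumulation:
#   x = torch.randn(10, 8)
#   fracs = [frac for frac, _ in split_args_and_kwargs(x, split_size = 4)]
#   # fracs == [0.4, 0.4, 0.2]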

# imagen trainer

def imagen_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        if self.imagen.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim = 0)

        return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))

    return inner


def restore_parts(state_dict_target, state_dict_from):
    for name, param in state_dict_from.items():

        if name not in state_dict_target:
            continue

        if param.size() == state_dict_target[name].size():
            state_dict_target[name].copy_(param)
        else:
            print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}")

    return state_dict_target
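
# `print_untrained_unets` in the trainers below references `NullUnet`, which is
# never imported in this file and would raise a NameError as written. A minimal
# placeholder is sketched here, assuming the project packages do not export
# one; it mirrors the do-nothing unet that imagen-pytorch uses in cascades.
class NullUnet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # single dummy parameter so .parameters() and .to(device) behave
        self.dummy_parameter = nn.Parameter(torch.tensor([0.]))

    def cast_model_parameters(self, *args, **kwargs):
        return self

    def forward(self, x, *args, **kwargs):
        # identity pass-through
        return x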

class ImagenTrainer(nn.Module):
    locked = False

    def __init__(
        self,
        model = None,
        imagen_checkpoint_path = None,
        use_ema = True,
        lr = 1e-4,
        eps = 1e-8,
        beta1 = 0.9,
        beta2 = 0.99,
        max_grad_norm = None,
        group_wd_params = True,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        only_train_unet_number = None,
        fp16 = False,
        precision = None,
        split_batches = True,
        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
        verbose = True,
        split_valid_fraction = 0.025,
        split_valid_from_train = False,
        split_random_seed = 42,
        checkpoint_path = None,
        checkpoint_every = None,
        checkpoint_fs = None,
        fs_kwargs: dict = None,
        max_checkpoints_keep = 20,
        # extra control keys (e.g. debug print flags)
        CKeys = None,
        **kwargs
    ):
        super().__init__()
        assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
        assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'

        # determine filesystem, using fsspec, for saving to local filesystem or cloud

        self.fs = checkpoint_fs

        if not exists(self.fs):
            fs_kwargs = default(fs_kwargs, {})
            self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
        
        # note: unlike the original (MJB) version, this trainer does not assert
        # isinstance(model, ProteinDesigner_B); it accepts any model that
        # exposes .imagen and .is_elucidated

        self.CKeys = CKeys

        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.imagen = model.imagen
        self.model = model
        self.is_elucidated = self.model.is_elucidated

        # create accelerator instance

        accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)

        assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
        accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')

        self.accelerator = Accelerator(**{
            'split_batches': split_batches,
            'mixed_precision': accelerator_mixed_precision,
            'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)],
            **accelerate_kwargs
        })

        ImagenTrainer.locked = self.is_distributed

        # cast data to fp16 at training time if needed

        self.cast_half_at_training = accelerator_mixed_precision == 'fp16'

        # grad scaler must be managed outside of accelerator

        grad_scaler_enabled = fp16
   
        self.num_unets = len(self.imagen.unets)

        self.use_ema = use_ema and self.is_main
        self.ema_unets = nn.ModuleList([])

        # keep track of what unet is being trained on
        # only going to allow 1 unet training at a time

        self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on

        # data related functions

        self.train_dl_iter = None
        self.train_dl = None

        self.valid_dl_iter = None
        self.valid_dl = None

        self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names

        # auto splitting validation from training, if dataset is passed in

        self.split_valid_from_train = split_valid_from_train

        assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
        self.split_valid_fraction = split_valid_fraction
        self.split_random_seed = split_random_seed

        # be able to finely customize learning rate, weight decay
        # per unet

        lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))

        for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
            optimizer = Adam(
                unet.parameters(),
                lr = unet_lr,
                eps = unet_eps,
                betas = (beta1, beta2),
                **kwargs
            )

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = grad_scaler_enabled)

            scheduler = warmup_scheduler = None

            if exists(unet_cosine_decay_max_steps):
                scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)

            if exists(unet_warmup_steps):
                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps)

                if not exists(scheduler):
                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

            # set on object

            setattr(self, f'optim{ind}', optimizer) # optimizers are not nn.Modules, so they are kept as individual attributes instead of an nn.ModuleList
            setattr(self, f'scaler{ind}', scaler)
            setattr(self, f'scheduler{ind}', scheduler)
            setattr(self, f'warmup{ind}', warmup_scheduler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        # step tracker and misc

        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        self.verbose = verbose

        # automatic set devices based on what accelerator decided

        self.imagen.to(self.device)
        self.to(self.device)

        # checkpointing

        assert not (exists(checkpoint_path) ^ exists(checkpoint_every)), 'checkpoint_path and checkpoint_every must be given together'
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.max_checkpoints_keep = max_checkpoints_keep

        self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main

        if exists(checkpoint_path) and self.can_checkpoint:
            bucket = url_to_bucket(checkpoint_path)

            if not self.fs.exists(bucket):
                self.fs.mkdir(bucket)

            self.load_from_checkpoint_folder()

        # only allowing training for unet

        self.only_train_unet_number = only_train_unet_number
        self.validate_and_set_unet_being_trained(only_train_unet_number)

    # computed values

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # optimizer helper functions

    def get_lr(self, unet_number):
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # function for allowing only one unet from being trained at a time

    def validate_and_set_unet_being_trained(self, unet_number = None):
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        self.unet_being_trained = self.accelerator.prepare(unet)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        optimizer = self.accelerator.prepare(optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # hacking accelerator due to not having separate gradscaler per optimizer

    def set_accelerator_scaler(self, unet_number):
        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler

    # helper print

    def print(self, msg):
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # validating the unet number

    def validate_unet_number(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be between 1 and {self.num_unets}'
        return unet_number

    # number of training steps taken

    def num_steps_taken(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        return self.steps[unet_number - 1].item()

    def print_untrained_unets(self):
        print_final_error = False

        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        if print_final_error:
            self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # data related functions

    def add_train_dataloader(self, dl = None):
        if not exists(dl):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'
        self.train_dl = self.accelerator.prepare(dl)

    def add_valid_dataloader(self, dl):
        if not exists(dl):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'
        self.valid_dl = self.accelerator.prepare(dl)

    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'

        valid_ds = None
        if self.split_valid_from_train:
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.train_dl = self.accelerator.prepare(dl)

        if not self.split_valid_from_train:
            return

        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)

    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.valid_dl = self.accelerator.prepare(dl)

    def create_train_iter(self):
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    def create_valid_iter(self):
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

    def train_step(self, unet_number = None, **kwargs):
        self.create_train_iter()
        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
        self.update(unet_number = unet_number)
        return loss

    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        self.create_valid_iter()

        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext

        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss

    def step_with_dl_iter(self, dl_iter, **kwargs):
        dl_tuple_output = cast_tuple(next(dl_iter))
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        loss = self.forward(**{**kwargs, **model_input})
        return loss

    # checkpointing functions

    @property
    def all_checkpoints_sorted(self):
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
        return sorted_checkpoints

    def load_from_checkpoint_folder(self, last_total_steps = -1):
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    def save_to_checkpoint_folder(self):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)

    # saving and loading functions

    def save(
        self,
        path,
        overwrite = True,
        without_optim_and_sched = False,
        **kwargs
    ):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite), f'{path} already exists and overwrite is disabled'

        self.reset_ema_unets_all_one_device()

        save_obj = dict(
            model = self.imagen.state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine if imagen config is available

        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')

            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        # save to path

        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')

    def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
        fs = self.fs

        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        assert fs.exists(path), f'{path} does not exist'

        self.reset_ema_unets_all_one_device()

        # to avoid extra GPU memory usage in main process when using Accelerate

        with fs.open(path) as f:
            loaded_obj = torch.load(f, map_location='cpu')

        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        if only_model:
            return loaded_obj

        self.steps.copy_(loaded_obj['steps'])

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            if exists(optimizer):
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except Exception:
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

    # managing ema unets and their devices

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def get_ema_unet(self, unet_number = None):
        if not self.use_ema:
            return

        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    def reset_ema_unets_all_one_device(self, device = None):
        if not self.use_ema:
            return

        device = default(device, self.device)
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        self.ema_unet_being_trained_index = -1

    @torch.no_grad()
    @contextmanager
    def use_ema_unets(self):
        if not self.use_ema:
            output = yield
            return output

        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        self.unets.eval()

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets                  # swap in exponential moving averaged unets for sampling

        output = yield

        self.imagen.unets = trainable_unets             # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output
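
    # usage sketch: `with trainer.use_ema_unets(): ...` temporarily swaps the
    # EMA unets into self.imagen for sampling; trainer.sample() below enters
    # this context automatically unless use_non_ema = True is passed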

    def print_unet_devices(self):
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # overriding state dict functions

    def state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # encoding text functions

    def encode_text(self, text, **kwargs):
        return self.imagen.encode_text(text, **kwargs)

    # forwarding functions and gradient step updates

    def update(self, unet_number = None):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        # set the grad scaler on the accelerator, since we are managing one per u-net

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # scheduler, if needed

        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs
                scheduler.step()

        self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()

    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()        
        
        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device = self.device, **kwargs)

        return output

    @partial(cast_torch_tensor, cast_fp16 = True)
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
    ):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        total_loss = 0.

        # + for debug
        if self.CKeys['Debug_TrainerPack'] == 1:
            print("In Trainer:Forward, check inputs:")
            print('args: ', len(args))
            print('args in: ', args[0].shape)
            print('kwargs: ', kwargs.keys())

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # + for debug
            if self.CKeys['Debug_TrainerPack'] == 1:
                print("after chunks,...")
                print('chun_frac: ', chunk_size_frac)
                print('chun_args: ', chunked_args)
                print('chun_kwargs: ', chunked_kwargs)

            with self.accelerator.autocast():
                loss = self.model(
                    *chunked_args,
                    unet_number = unet_number,
                    **chunked_kwargs
                )
                loss = loss * chunk_size_frac

            # + for debug
            if self.CKeys['Debug_TrainerPack'] == 1:
                print('part chun loss: ', loss)

            total_loss += loss  # note: accumulated as a tensor (.item() not used), so the loss returned to the caller stays on-device

            if self.training:
                self.accelerator.backward(loss)

        return total_loss
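
# end-to-end usage sketch for ImagenTrainer (hypothetical names; model and
# dataset construction elided; `designer` is assumed to expose .imagen and
# .is_elucidated, as ProteinDesigner_B does):
#
#   trainer = ImagenTrainer(
#       model = designer,
#       lr = 1e-4,
#       CKeys = {'Debug_TrainerPack': 0},
#   )
#   trainer.add_train_dataset(train_ds, batch_size = 8, shuffle = True)
#   for _ in range(num_train_steps):
#       loss = trainer.train_step(unet_number = 1)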
    
# ========================================================
# 
class ImagenTrainer_ModelB(nn.Module):
    locked = False

    def __init__(
        self,
        model = None,
        imagen_checkpoint_path = None,
        use_ema = True,
        lr = 1e-4,
        eps = 1e-8,
        beta1 = 0.9,
        beta2 = 0.99,
        max_grad_norm = None,
        group_wd_params = True,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        only_train_unet_number = None,
        fp16 = False,
        precision = None,
        split_batches = True,
        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
        verbose = True,
        split_valid_fraction = 0.025,
        split_valid_from_train = False,
        split_random_seed = 42,
        checkpoint_path = None,
        checkpoint_every = None,
        checkpoint_fs = None,
        fs_kwargs: dict = None,
        max_checkpoints_keep = 20,
        # extra control keys (e.g. debug print flags)
        CKeys = None,
        **kwargs
    ):
        super().__init__()
        assert not ImagenTrainer_ModelB.locked, 'ImagenTrainer_ModelB can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
        assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'

        # determine filesystem, using fsspec, for saving to local filesystem or cloud

        self.fs = checkpoint_fs

        if not exists(self.fs):
            fs_kwargs = default(fs_kwargs, {})
            self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
        
        # unlike the base ImagenTrainer above, this variant is restricted to
        # ProteinDesigner_B models
        assert isinstance(model, ProteinDesigner_B)

        self.CKeys = CKeys

        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.imagen = model.imagen
        self.model = model
        self.is_elucidated = self.model.is_elucidated

        # create accelerator instance

        accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)

        assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
        accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')

        self.accelerator = Accelerator(**{
            'split_batches': split_batches,
            'mixed_precision': accelerator_mixed_precision,
            'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)],
            **accelerate_kwargs
        })

        ImagenTrainer_ModelB.locked = self.is_distributed

        # cast data to fp16 at training time if needed

        self.cast_half_at_training = accelerator_mixed_precision == 'fp16'

        # grad scaler must be managed outside of accelerator

        grad_scaler_enabled = fp16
   
        self.num_unets = len(self.imagen.unets)

        self.use_ema = use_ema and self.is_main
        self.ema_unets = nn.ModuleList([])

        # keep track of what unet is being trained on
        # only going to allow 1 unet training at a time

        self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on

        # data related functions

        self.train_dl_iter = None
        self.train_dl = None

        self.valid_dl_iter = None
        self.valid_dl = None

        self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names

        # auto splitting validation from training, if dataset is passed in

        self.split_valid_from_train = split_valid_from_train

        assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
        self.split_valid_fraction = split_valid_fraction
        self.split_random_seed = split_random_seed

        # be able to finely customize learning rate, weight decay
        # per unet

        lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))

        for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
            optimizer = Adam(
                unet.parameters(),
                lr = unet_lr,
                eps = unet_eps,
                betas = (beta1, beta2),
                **kwargs
            )

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = grad_scaler_enabled)

            scheduler = warmup_scheduler = None

            if exists(unet_cosine_decay_max_steps):
                scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)

            if exists(unet_warmup_steps):
                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps)

                if not exists(scheduler):
                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

            # set on object

            setattr(self, f'optim{ind}', optimizer) # optimizers are not nn.Modules, so they are kept as individual attributes instead of an nn.ModuleList
            setattr(self, f'scaler{ind}', scaler)
            setattr(self, f'scheduler{ind}', scheduler)
            setattr(self, f'warmup{ind}', warmup_scheduler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        # step tracker and misc

        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        self.verbose = verbose

        # automatic set devices based on what accelerator decided

        self.imagen.to(self.device)
        self.to(self.device)

        # checkpointing

        assert not (exists(checkpoint_path) ^ exists(checkpoint_every)), 'checkpoint_path and checkpoint_every must be given together'
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.max_checkpoints_keep = max_checkpoints_keep

        self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main

        if exists(checkpoint_path) and self.can_checkpoint:
            bucket = url_to_bucket(checkpoint_path)

            if not self.fs.exists(bucket):
                self.fs.mkdir(bucket)

            self.load_from_checkpoint_folder()

        # only allowing training for unet

        self.only_train_unet_number = only_train_unet_number
        self.validate_and_set_unet_being_trained(only_train_unet_number)

    # computed values

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # optimizer helper functions

    def get_lr(self, unet_number):
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # function for allowing only one unet from being trained at a time

    def validate_and_set_unet_being_trained(self, unet_number = None):
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        self.unet_being_trained = self.accelerator.prepare(unet)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        optimizer = self.accelerator.prepare(optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # hacking accelerator due to not having separate gradscaler per optimizer

    def set_accelerator_scaler(self, unet_number):
        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler

    # helper print

    def print(self, msg):
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # validating the unet number

    def validate_unet_number(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be between 1 and {self.num_unets}'
        return unet_number

    # number of training steps taken

    def num_steps_taken(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        return self.steps[unet_number - 1].item()

    def print_untrained_unets(self):
        print_final_error = False

        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        if print_final_error:
            self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # data related functions

    def add_train_dataloader(self, dl = None):
        if not exists(dl):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'
        self.train_dl = self.accelerator.prepare(dl)

    def add_valid_dataloader(self, dl):
        if not exists(dl):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'
        self.valid_dl = self.accelerator.prepare(dl)

    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'

        valid_ds = None
        if self.split_valid_from_train:
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.train_dl = self.accelerator.prepare(dl)

        if not self.split_valid_from_train:
            return

        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)

    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.valid_dl = self.accelerator.prepare(dl)

    def create_train_iter(self):
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    def create_valid_iter(self):
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

    def train_step(self, unet_number = None, **kwargs):
        self.create_train_iter()
        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
        self.update(unet_number = unet_number)
        return loss

    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        self.create_valid_iter()

        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext

        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss

    def step_with_dl_iter(self, dl_iter, **kwargs):
        dl_tuple_output = cast_tuple(next(dl_iter))
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        loss = self.forward(**{**kwargs, **model_input})
        return loss

    # checkpointing functions

    @property
    def all_checkpoints_sorted(self):
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
        return sorted_checkpoints

    def load_from_checkpoint_folder(self, last_total_steps = -1):
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    def save_to_checkpoint_folder(self):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)

    # saving and loading functions

    def save(
        self,
        path,
        overwrite = True,
        without_optim_and_sched = False,
        **kwargs
    ):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite), f'{path} already exists and overwrite is disabled'

        self.reset_ema_unets_all_one_device()

        save_obj = dict(
            model = self.imagen.state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine if imagen config is available

        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')

            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        # save to path

        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')

    def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
        fs = self.fs

        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        assert fs.exists(path), f'{path} does not exist'

        self.reset_ema_unets_all_one_device()

        # to avoid extra GPU memory usage in main process when using Accelerate

        with fs.open(path) as f:
            loaded_obj = torch.load(f, map_location='cpu')

        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        if only_model:
            return loaded_obj

        self.steps.copy_(loaded_obj['steps'])

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            if exists(optimizer):
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except Exception:
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

    # managing ema unets and their devices

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def get_ema_unet(self, unet_number = None):
        if not self.use_ema:
            return

        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    def reset_ema_unets_all_one_device(self, device = None):
        if not self.use_ema:
            return

        device = default(device, self.device)
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        self.ema_unet_being_trained_index = -1

    @torch.no_grad()
    @contextmanager
    def use_ema_unets(self):
        if not self.use_ema:
            output = yield
            return output

        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        self.unets.eval()

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets                  # swap in exponential moving averaged unets for sampling

        output = yield

        self.imagen.unets = trainable_unets             # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output
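
    # usage sketch for the context manager above (hypothetical `trainer`
    # instance): the EMA weights are swapped into self.imagen only for the
    # duration of the `with` block, then the training unets are restored and
    # the EMA copies are moved back to their original devices.
    #
    #     with trainer.use_ema_unets():
    #         samples = trainer.imagen.sample(...)
    #
    # note that sample() below already wraps itself in this context unless
    # use_non_ema = True is passed.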

    def print_unet_devices(self):
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # overriding state dict functions

    def state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # encoding text functions

    def encode_text(self, text, **kwargs):
        return self.imagen.encode_text(text, **kwargs)

    # forwarding functions and gradient step updates

    def update(self, unet_number = None):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        # set the grad scaler on the accelerator, since we are managing one per u-net

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # scheduler, if needed

        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs
                scheduler.step()

        self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()
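
    # manual training-step sketch (hypothetical `trainer` and `batch`): the
    # forward call accumulates gradients, then update() steps the optimizer,
    # EMA weights and schedulers for that unet.
    #
    #     loss = trainer(batch, unet_number = 1, max_batch_size = 4)
    #     trainer.update(unet_number = 1)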

    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()        
        
        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device = self.device, **kwargs)

        return output

    @partial(cast_torch_tensor, cast_fp16 = True)
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
    ):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        total_loss = 0.

        # + for debug
        if self.CKeys['Debug_TrainerPack']==1:
            print("In Trainer:Forward, check inputs:")
            print('args: ', len(args))
            print('args in:',args[0].shape)
            print('kwargs: ', kwargs.keys())
        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # + for debug
            if self.CKeys['Debug_TrainerPack']==1:
                print("after chunks,...")
                print('chun_frac: ', chunk_size_frac)
                print('chun_args: ', chunked_args)
                print('chun_kwargs: ', chunked_kwargs)
            
            with self.accelerator.autocast():
                loss = self.model(
                    *chunked_args, 
                    unet_number = unet_number, 
                    **chunked_kwargs
                )
                loss = loss * chunk_size_frac
            
            # + for debug
            if self.CKeys['Debug_TrainerPack']==1:
                print('part chun loss: ', loss)

            total_loss += loss#.item()

            if self.training:
                self.accelerator.backward(loss)

        return total_loss
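
# minimal end-to-end usage sketch for the trainer above (hypothetical
# `designer_model` and `train_ds`, built elsewhere with this package's
# ProteinDesigner_B):
#
#     trainer = ImagenTrainer (model = designer_model, lr = 1e-4, use_ema = True)
#     trainer.add_train_dataset (train_ds, batch_size = 8)
#     for _ in range (10000):
#         loss = trainer.train_step (unet_number = 1)
#
# the class below is a legacy copy kept for reference.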
    
class ImagenTrainer_Old(nn.Module):
    locked = False

    def __init__(
        self,
        #imagen = None,
        model = None,
        
        imagen_checkpoint_path = None,
        use_ema = True,
        lr = 1e-4,
        eps = 1e-8,
        beta1 = 0.9,
        beta2 = 0.99,
        max_grad_norm = None,
        group_wd_params = True,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        only_train_unet_number = None,
        fp16 = False,
        precision = None,
        split_batches = True,
        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
        verbose = True,
        split_valid_fraction = 0.025,
        split_valid_from_train = False,
        split_random_seed = 42,
        checkpoint_path = None,
        checkpoint_every = None,
        checkpoint_fs = None,
        fs_kwargs: dict = None,
        max_checkpoints_keep = 20,
        **kwargs
    ):
        super().__init__()
        assert not ImagenTrainer_Old.locked, 'ImagenTrainer_Old can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
        assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'

        # determine filesystem, using fsspec, for saving to local filesystem or cloud

        self.fs = checkpoint_fs

        if not exists(self.fs):
            fs_kwargs = default(fs_kwargs, {})
            self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
        
        # # -----------------------------------
        # # from MJB
        # assert isinstance(model.imagen, (ProteinDesigner_B))
        
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.imagen = model.imagen

        self.model = model
        self.is_elucidated = self.model.is_elucidated
        # create accelerator instance

        accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)

        assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
        accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')

        self.accelerator = Accelerator(**{
            'split_batches': split_batches,
            'mixed_precision': accelerator_mixed_precision,
            'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)]
        , **accelerate_kwargs})

        ImagenTrainer_Old.locked = self.is_distributed

        # cast data to fp16 at training time if needed

        self.cast_half_at_training = accelerator_mixed_precision == 'fp16'

        # grad scaler must be managed outside of accelerator

        grad_scaler_enabled = fp16
   
        self.num_unets = len(self.imagen.unets)

        self.use_ema = use_ema and self.is_main
        self.ema_unets = nn.ModuleList([])

        # keep track of what unet is being trained on
        # only going to allow 1 unet training at a time

        self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on

        # data related functions

        self.train_dl_iter = None
        self.train_dl = None

        self.valid_dl_iter = None
        self.valid_dl = None

        self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names

        # auto splitting validation from training, if dataset is passed in

        self.split_valid_from_train = split_valid_from_train

        assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
        self.split_valid_fraction = split_valid_fraction
        self.split_random_seed = split_random_seed

        # be able to finely customize learning rate, weight decay
        # per unet

        lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))

        for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
            optimizer = Adam(
                unet.parameters(),
                lr = unet_lr,
                eps = unet_eps,
                betas = (beta1, beta2),
                **kwargs
            )

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = grad_scaler_enabled)

            scheduler = warmup_scheduler = None

            if exists(unet_cosine_decay_max_steps):
                scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)

            if exists(unet_warmup_steps):
                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps)

                if not exists(scheduler):
                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

            # set on object

            setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
            setattr(self, f'scaler{ind}', scaler)
            setattr(self, f'scheduler{ind}', scheduler)
            setattr(self, f'warmup{ind}', warmup_scheduler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        # step tracker and misc

        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        self.verbose = verbose

        # automatic set devices based on what accelerator decided

        self.imagen.to(self.device)
        self.to(self.device)

        # checkpointing

        assert not (exists(checkpoint_path) ^ exists(checkpoint_every)), 'checkpoint_path and checkpoint_every must be given together'
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.max_checkpoints_keep = max_checkpoints_keep

        self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main

        if exists(checkpoint_path) and self.can_checkpoint:
            bucket = url_to_bucket(checkpoint_path)

            if not self.fs.exists(bucket):
                self.fs.mkdir(bucket)

            self.load_from_checkpoint_folder()

        # only allowing training for unet

        self.only_train_unet_number = only_train_unet_number
        self.validate_and_set_unet_being_trained(only_train_unet_number)

    # computed values

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # optimizer helper functions

    def get_lr(self, unet_number):
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # function for allowing only one unet from being trained at a time

    def validate_and_set_unet_being_trained(self, unet_number = None):
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'only one unet can be trained at a time; save the trainer into a checkpoint and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        self.unet_being_trained = self.accelerator.prepare(unet)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        optimizer = self.accelerator.prepare(optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # hacking accelerator due to not having separate gradscaler per optimizer

    def set_accelerator_scaler(self, unet_number):
        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler

    # helper print

    def print(self, msg):
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # validating the unet number

    def validate_unet_number(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
        return unet_number

    # number of training steps taken

    def num_steps_taken(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        return self.steps[unet_number - 1].item()

    def print_untrained_unets(self):
        print_final_error = False

        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        if print_final_error:
            self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # data related functions

    def add_train_dataloader(self, dl = None):
        if not exists(dl):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'
        self.train_dl = self.accelerator.prepare(dl)

    def add_valid_dataloader(self, dl):
        if not exists(dl):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'
        self.valid_dl = self.accelerator.prepare(dl)

    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'

        valid_ds = None
        if self.split_valid_from_train:
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.train_dl = self.accelerator.prepare(dl)

        if not self.split_valid_from_train:
            return

        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)

    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'

        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.valid_dl = self.accelerator.prepare(dl)

    def create_train_iter(self):
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    def create_valid_iter(self):
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

    def train_step(self, unet_number = None, **kwargs):
        self.create_train_iter()
        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
        self.update(unet_number = unet_number)
        return loss

    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        self.create_valid_iter()

        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext

        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss

    def step_with_dl_iter(self, dl_iter, **kwargs):
        dl_tuple_output = cast_tuple(next(dl_iter))
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        loss = self.forward(**{**kwargs, **model_input})
        return loss
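
    # sketch of the dataloader contract assumed above: each batch is a tuple
    # matched positionally against self.dl_tuple_output_keywords_names, e.g.
    #
    #     dl yields (images, text_embeds, text_masks, cond_images)
    #     -> self.forward(images = ..., text_embeds = ..., text_masks = ..., cond_images = ...)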

    # checkpointing functions

    @property
    def all_checkpoints_sorted(self):
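        # checkpoints are written as 'checkpoint.<total_steps>.pt' (see
        # save_to_checkpoint_folder below), so sorting on the '.'-separated
        # second-to-last field orders them by step count, newest first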
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
        return sorted_checkpoints

    def load_from_checkpoint_folder(self, last_total_steps = -1):
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    def save_to_checkpoint_folder(self):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)

    # saving and loading functions

    def save(
        self,
        path,
        overwrite = True,
        without_optim_and_sched = False,
        **kwargs
    ):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite), f'{path} already exists and overwrite is set to False'

        self.reset_ema_unets_all_one_device()

        save_obj = dict(
            model = self.imagen.state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine if imagen config is available

        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')

            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        #save to path

        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')

    def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
        fs = self.fs

        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        assert fs.exists(path), f'{path} does not exist'

        self.reset_ema_unets_all_one_device()

        # to avoid extra GPU memory usage in main process when using Accelerate

        with fs.open(path) as f:
            loaded_obj = torch.load(f, map_location='cpu')

        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        if only_model:
            return loaded_obj

        self.steps.copy_(loaded_obj['steps'])

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            if exists(optimizer):
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except Exception:
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

    # managing ema unets and their devices

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def get_ema_unet(self, unet_number = None):
        if not self.use_ema:
            return

        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    def reset_ema_unets_all_one_device(self, device = None):
        if not self.use_ema:
            return

        device = default(device, self.device)
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        self.ema_unet_being_trained_index = -1

    @torch.no_grad()
    @contextmanager
    def use_ema_unets(self):
        if not self.use_ema:
            output = yield
            return output

        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        self.unets.eval()

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets                  # swap in exponential moving averaged unets for sampling

        output = yield

        self.imagen.unets = trainable_unets             # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output

    def print_unet_devices(self):
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # overriding state dict functions

    def state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        self.reset_ema_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # encoding text functions

    def encode_text(self, text, **kwargs):
        return self.imagen.encode_text(text, **kwargs)

    # forwarding functions and gradient step updates

    def update(self, unet_number = None):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        # set the grad scaler on the accelerator, since we are managing one per u-net

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # scheduler, if needed

        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs
                scheduler.step()

        self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()

    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()        
        
        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device = self.device, **kwargs)

        return output

    @partial(cast_torch_tensor, cast_fp16 = True)
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
    ):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        total_loss = 0.

        # + for debug
        print('args: ', len(args))
        print('args in:',args[0].shape)
        print('kwargs: ', kwargs.keys())
        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # + for debug
            print('chun_frac: ', chunk_size_frac)
            print('chun_args: ', chunked_args)
            print('chun_kwargs: ', chunked_kwargs)
            
            with self.accelerator.autocast():
                loss = self.model(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                loss = loss * chunk_size_frac
            
            print('chunk loss: ', loss)

            total_loss += loss#.item()

            if self.training:
                self.accelerator.backward(loss)

        print('finished forward pass over all chunks')
        return total_loss

    
def write_fasta (sequence, filename_out):
    # write a single-record FASTA file; the record name is the output filename
    with open (filename_out, mode ='w') as f:
        f.write (f'>{filename_out}\n')
        f.write (f'{sequence}\n')
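
# usage sketch (hypothetical sequence and filename):
#
#     write_fasta ('MKTAYIAKQR', 'sample_0.fasta')
#
# note that the record name in the FASTA header is simply the output filename.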


#
def sample_sequence_FromModelB (
    model,
    X=None, #this is the target conventionally when using text embd
    flag=0,
    cond_scales=1.,
    foldproteins=False,
    X_string=None,
    x_data=None,  
    skip_steps=0,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,
    # ++++++++++++++++++++++++
    ynormfac=1,
    train_unet_number=1,
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    prefix=None,
    tokenizer_y=None,
               ):
    steps=0
    e=flag

    #num_samples = min (num_samples,y_train_batch.shape[0] )
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_string (from string)...")

    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditioning x_data...")
        lenn_val=len(x_data)
        print (x_data)
        
    print ('Device: ', device)


    for iisample in range (lenn_val):
        X_cond=None
        if X_string is None and X is not None: #only do if X provided
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None:
            # keep a local name so the input X is not overwritten across samples
            XX = tokenizer_X.texts_to_sequences(X_string[iisample])
            XX = sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post')
            XX = np.array(XX)
            X_cond = torch.from_numpy(XX).float()/Xnormfac
            print ('Tokenized and processed: ', X_cond)
        
        print ("X_cond=", X_cond)
        
        result=model.sample ( 
            x=X_cond,
            stop_at_unet_number=train_unet_number ,
            cond_scale=cond_scales ,
            x_data=x_data, skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,device=device,
            # ++++++++++++++++++++++++++
            tokenizer_X=tokenizer_X,
            Xnormfac=Xnormfac,
            max_length=max_length,
                            )
        result=torch.round(result*ynormfac)
        
        plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted')
        #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
        plt.legend()

        outname = prefix+ f"sampled_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
        plt.savefig(outname, dpi=200)
        plt.show ()

        to_rev=result[:,0,:] 
        to_rev=to_rev.long().cpu().detach().numpy()
        print (to_rev.shape)
        y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

        for iii in range (len(y_data_reversed)):
            y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
        
        ### reverse secondary structure input....
        if X_cond is not None:
            X_cond=torch.round(X_cond*Xnormfac)

            to_rev=X_cond[:,:] 
            to_rev=to_rev.long().cpu().detach().numpy()
            print (to_rev.shape)
            X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)

            for iii in range (len(y_data_reversed)):
                X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
        if x_data is not None:
            X_data_reversed=x_data #is already in sequence format..
               

        print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample])
        if foldproteins:
            
            if X_cond is not None:
                xbc=X_cond[iisample,:].cpu().detach().numpy()
                out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}'
            if x_data is not None:
                #xbc=x_data[iisample] 
                out_nam=x_data[iisample] 
             
            
            tempname='temp'
            pdb_file=foldandsavePDB (
                sequence=y_data_reversed[0], 
                filename_out=tempname, 
                num_cycle=num_cycle, 
                flag=flag,
                # +++++++++++++++++++
                prefix=prefix
            )

            out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta'

            write_fasta (y_data_reversed[0], out_nam_fasta)            
            
        
            out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
            # print('Debug 1: out: ', out_nam)
            # print('Debug 2: in: ', pdb_file)
            shutil.copy (pdb_file, out_nam) #source, dest
            # cmd_line = 'cp ' + pdb_file + ' ' + out_nam
            # print(cmd_line)
            # os.popen(cmd_line)
            # print('Debug 3')
            pdb_file=out_nam

            print (f"Properly named PDB file produced: {pdb_file}")
            #flag=1000
            view=show_pdb(pdb_file=pdb_file, flag=flag,
                          show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color)
            view.show()


        steps=steps+1

        # NOTE: this return sits inside the sampling loop, so only the first
        # sample's PDB file is produced and returned
        return pdb_file 
    
def sample_loop_FromModelB (model,
                train_loader,
                cond_scales=[7.5], #list of cond scales - each sampled...
                num_samples=2, #how many samples produced every time tested.....
                timesteps=100,
                 flag=0,foldproteins=False,
                 use_text_embedd=True,skip_steps=0,
                 # +++++++++++++++++++
                 train_unet_number=1,
                 ynormfac=1,
                 prefix=None,
                 tokenizer_y=None,
                 Xnormfac=1,
                 tokenizer_X=None,
                 
               ):
    steps=0
    e=flag
    for item in train_loader:

            X_train_batch= item[0].to(device)
            y_train_batch=item[1].to(device)

            GT=y_train_batch.cpu().detach() 
                    
            GT= GT.unsqueeze(1)
            num_samples = min (num_samples,y_train_batch.shape[0] )
            print (f"Producing {num_samples} samples...")
            
            print ('X_train_batch shape: ', X_train_batch.shape)

            for iisample in range (len (cond_scales)):
                
                if use_text_embedd:
                    result=model.sample (x= X_train_batch,stop_at_unet_number=train_unet_number ,
                                         cond_scale=cond_scales[iisample], device=device, skip_steps=skip_steps)
                else:
                    result=model.sample (x= None, x_data_tokenized= X_train_batch,
                                         stop_at_unet_number=train_unet_number ,
                                         cond_scale=cond_scales[iisample],device=device,skip_steps=skip_steps)
                    
                result=torch.round(result*ynormfac)
                GT=torch.round (GT*ynormfac)

                for samples in range  (num_samples):
                    print ("sample ", samples, "out of ", num_samples)
                    
                    plt.plot (result[samples,0,:].cpu().detach().numpy(),label= f'Predicted')
                    plt.plot (GT[samples,0,:],label= f'GT {0}')
                    plt.legend()

                    outname = prefix+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                   
                    plt.savefig(outname, dpi=200)
                    plt.show ()
                    
                    #reverse y sequence
                    to_rev=result[:,0,:]
                    to_rev=to_rev.long().cpu().detach().numpy()
                    
                    y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                    for iii in range (len(y_data_reversed)):
                        y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
                        
                    #reverse GT_y sequence
                    to_rev=GT[:,0,:]
                    to_rev=to_rev.long().cpu().detach().numpy()
                    
                    GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                    for iii in range (len(y_data_reversed)):
                        GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")
                    
                    ### reverse secondary structure input....
                    to_rev=torch.round (X_train_batch[:,:]*Xnormfac)
                    to_rev=to_rev.long().cpu().detach().numpy()
                   
                    X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)

                    for iii in range (len(y_data_reversed)):
                        X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")

                    print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples])
                    print (f"Ground truth: {GT_y_data_reversed[samples]}")
                   
                    if foldproteins:
                        xbc=X_train_batch[samples,:].cpu().detach().numpy()
                        out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                        tempname='temp'
                        pdb_file=foldandsavePDB (
                            sequence=y_data_reversed[samples], 
                            filename_out=tempname, 
                            num_cycle=16, flag=flag,
                            # +++++++++++++++++++
                            prefix=prefix
                        )
                        
                        #out_nam=f'{prefix}{out_nam}.pdb'
                        out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
                        print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                        shutil.copy (pdb_file, out_nam) #source, dest
                        pdb_file=out_nam
                        print (f"Properly named PDB file produced: {pdb_file}")
                        
                        view=show_pdb(pdb_file=pdb_file, flag=flag, show_sidechains=show_sidechains,  show_mainchains=show_mainchains, color=color)
                        view.show()

                    steps=steps+1
            if steps>num_samples:
                break
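
# usage sketch for the sampling loop above (hypothetical objects; the loader
# must yield (X, y) batches in tokenized, normalized form):
#
#     sample_loop_FromModelB (model, train_loader,
#                             cond_scales=[7.5], num_samples=2,
#                             train_unet_number=1, ynormfac=ynormfac,
#                             prefix='./output/', tokenizer_y=tokenizer_y,
#                             Xnormfac=Xnormfac, tokenizer_X=tokenizer_X)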
# ++
def sample_sequence_omegafold_pLM_ModelB (
    model,
    X=None, #this is the target conventionally when using text embd
    flag=0,
    cond_scales=1.,
    foldproteins=False,
    X_string=None,
    x_data=None,  
    skip_steps=0,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,
    # ++++++++++++++++++++++++
    ynormfac=1,
    train_unet_number=1,
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    prefix=None,
    tokenizer_y=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only when foldproteins=True
    # ++
    pLM_Model=None, # pLM_Model,
    pLM_Model_Name=None, # pLM_Model_Name,
    image_channels=None, # image_channels,
    pLM_alphabet=None, # esm_alphabet,
):
    
    # steps=0
    # e=flag

    #num_samples = min (num_samples,y_train_batch.shape[0] )
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_string (from string)...")

    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditioning x_data...")
        lenn_val=len(x_data)
        print (x_data)
        
    print ('Device: ', device)
    
    pdb_file_list=[]
    fasta_file_list=[]

    # + for debug
    print('tot ', lenn_val)
    for iisample in range (lenn_val):
        print("Working on ", iisample)
        X_cond=None  # this is for text-conditioning
        if X_string is None and X is not None: #only do if X provided
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None:
            XX = tokenizer_X.texts_to_sequences(X_string[iisample])
            XX = sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post')
            XX = np.array(XX)
            X_cond = torch.from_numpy(XX).float()/Xnormfac
            print ('Tokenized and processed: ', X_cond)
        
        print ("X_cond=", X_cond)
        
        # # --
        # result=model.sample ( 
        #     x=X_cond,
        #     stop_at_unet_number=train_unet_number ,
        #     cond_scale=cond_scales ,
        #     x_data=x_data[iisample], 
        #     # ++
        #     x_data_tokenized=
        #     skip_steps=skip_steps,
        #     inpaint_images = inpaint_images,
        #     inpaint_masks = inpaint_masks,
        #     inpaint_resample_times = inpaint_resample_times,
        #     init_images = init_images,device=device,
        #     # ++++++++++++++++++++++++++
        #     tokenizer_X=tokenizer_X,
        #     Xnormfac=Xnormfac,
        #     max_length=max_length,
        #                     )
        # ++
        # use cond_image as the conditioning, via x_data_tokenized channel
        
        # -----------------------------------------------------------------
        # both branches below handle image-based conditioning (cond_img), not text conditioning
        if tokenizer_X is not None:
            # for SecStr+ModelB
            result_embedding=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=x_data[iisample],  # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim
                # ++
                x_data_tokenized=None,
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,device=device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X,
                Xnormfac=Xnormfac,
                max_length=max_length,
            )
        else:
            # for ForcPath+ModelB:
            # for model.sample() here using x_data_tokenized channel
            #
            x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac)
            x_data_tokenized=x_data_tokenized.to(torch.float)
            # here, only one input list is read in
            x_data_tokenized=x_data_tokenized.unsqueeze(0) # [batch=1, seq_len]
            # leave channel expansion for the self.sample() to handle
            
            # + for debug:
            if CKeys['Debug_TrainerPack']==3:
                print("x_data_tokenized dim: ", x_data_tokenized.shape)
                print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                print("test x_data_tokenized!=None: ", x_data_tokenized!=None)
            
            result_embedding=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=None, 
                # ++
                x_data_tokenized=x_data_tokenized,
                #
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,device=device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X,
                Xnormfac=Xnormfac,
                max_length=max_length,
            )
            
        # # --
        # result=torch.round(result*ynormfac) # (batch=1, channel=1, seq_len)
        
        # ++ for pLM
        # full record
        # result_embedding as image.dim: [batch, channels, seq_len]
        # result_tokens.dim: [batch, seq_len]
        if image_channels==33:
            result_tokens,result_logits = convert_into_tokens_using_prob(
                result_embedding,
                pLM_Model_Name,
            )
        else:
            result_tokens,result_logits = convert_into_tokens(
                pLM_Model, 
                result_embedding,
                pLM_Model_Name,
            )
        # +++++++++++++++++++++++++++++++++
        result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]
        
        # + for debug
        print('result dim: ', result.shape)
        
        # plot sequence token code: esm (33 tokens)
        fig=plt.figure()
        plt.plot (
            result[0,0,:].cpu().detach().numpy(),
            label= f'Predicted'
        )
        #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
        plt.legend()
        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
        if IF_showfig==1:
            plt.show ()
        else:
            plt.savefig(outname, dpi=200)
        plt.close()
        
        # 
        # # --
        # to_rev=result[:,0,:] 
        # to_rev=to_rev.long().cpu().detach().numpy()
        # print (to_rev.shape)
        # y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
        # 
        # for iii in range (len(y_data_reversed)):
        #     y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
        
        # # ++: for model A: no mask is provided from input
        # # reverse the PREDICTED y into a foldable sequence
        # # save this block for Model A 
        # to_rev=result[:,0,:] # token (batch,seq_len)
        # y_data_reversed=decode_many_ems_token_rec_for_folding(
        #     to_rev,
        #     result_logits,
        #     pLM_alphabet,
        #     pLM_Model,
        # )
        
        # if CKeys['Debug_TrainerPack']==3:
        #     print("on foldable result: ", to_rev[0])
        #     print("on result_logits: ", result_logits[0])
        #     a = decode_one_ems_token_rec_for_folding(
        #         to_rev[0],
        #         result_logits[0],
        #         pLM_alphabet,
        #         pLM_Model,
        #     )
        #     print('One resu: ', a)
        #     print("on y_data_reversed: ", y_data_reversed[0])
        #    print("y_data_reversed.type", y_data_reversed.dtype)
        #
        
        # ++: for model B: using mask from the input
        # extract the mask/seq_len from input if possible
        if tokenizer_X is not None:
            # for SecStr+ModelB
            result_mask = read_mask_from_input(
                tokenized_data=None, 
                mask_value=None,
                seq_data=x_data[iisample],
                max_seq_length=max_length,
            )
        else:
            # for ForcPath+ModelB
            result_mask = read_mask_from_input(
                tokenized_data=x_data_tokenized, # None, 
                mask_value=0, # None,
                seq_data=None, # x_data[iisample],
                max_seq_length=None, # max_length,
            )
        
        to_rev=result[:,0,:] # token (batch,seq_len)
        if CKeys['Debug_TrainerPack']==3:
            print("on foldable result: ", to_rev[0])
            print("on result_logits: ", result_logits[0])
            print("on mask: ", result_mask[0])
            a = decode_one_ems_token_rec_for_folding_with_mask(
                to_rev[0],
                result_logits[0],
                pLM_alphabet,
                pLM_Model,
                result_mask[0],
            )
            print('One resu: ', a)

        y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask(
            to_rev,
            result_logits,
            pLM_alphabet,
            pLM_Model,
            result_mask,
        )
        if CKeys['Debug_TrainerPack']==3:
            print("on y_data_reversed[0]: ", y_data_reversed[0])

        ### reverse secondary structure input....
        if X_cond is not None:
            X_cond=torch.round(X_cond*Xnormfac)

            to_rev=X_cond[:,:] 
            to_rev=to_rev.long().cpu().detach().numpy()
            print (to_rev.shape)
            X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)

            for iii in range (len(y_data_reversed)):
                X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
        if x_data is not None:
            X_data_reversed=x_data #is already in sequence format..
               

        # print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence", y_data_reversed[iisample])
        print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed)
        
        # + for debug
        print("================================================")
        print("foldproteins: ", foldproteins)
        
        if not foldproteins:
            pdb_file=None
            
        else:
            
            if X_cond is not None:
                xbc=X_cond[iisample,:].cpu().detach().numpy()
                out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}'
            if x_data is not None:
                #xbc=x_data[iisample] 
                # ----------------------------------
                # this one can be too long for a name
                out_nam=x_data[iisample] 
                # ++++++++++++++++++++++++++++++++++
                # use the sample index instead, to keep the filename short
                out_nam=iisample
             
            
            tempname='temp'
            pdb_file, fasta_file=foldandsavePDB_pdb_fasta (
                sequence=y_data_reversed[0], 
                filename_out=tempname, 
                num_cycle=num_cycle, 
                flag=flag,
                # +++++++++++++++++++
                # prefix=prefix,
                prefix=sample_dir,
            )

            # out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta'
            # ------------------------------------------
            # this one can be too long for a name
            # out_nam_fasta=f'{sample_dir}{out_nam}_{flag}_{e}_{iisample}.fasta'
            # ++++++++++++++++++++++++++++++++++++++++++
            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'

            write_fasta (y_data_reversed[0], out_nam_fasta)            
            
        
            # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
            # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
            # -------------------------------------------
            # this one can be too long for a name
            # However, the input X is recorded in the code
            # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb'
            # +++++++++++++++++++++++++++++++++++++++++++
            out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
            
            # print('Debug 1: out: ', out_nam)
            # print('Debug 2: in: ', pdb_file)
            shutil.copy (pdb_file, out_nam) #source, dest
            shutil.copy (fasta_file, out_nam_fasta)
            # cmd_line = 'cp ' + pdb_file + ' ' + out_nam
            # print(cmd_line)
            # os.popen(cmd_line)
            # print('Debug 3')
            # clean the slate to avoid mistakenly reusing the previous fasta file
            os.remove (pdb_file)
            os.remove (fasta_file)
            
            pdb_file=out_nam
            fasta_file=out_nam_fasta
            pdb_file_list.append(pdb_file)
            fasta_file_list.append(fasta_file)
            
            
            print (f"Properly named PDB file produced: {pdb_file}")
            if IF_showfig==1:
                #flag=1000
                view=show_pdb(
                    pdb_file=pdb_file, 
                    flag=flag,
                    show_sidechains=show_sidechains, 
                    show_mainchains=show_mainchains, 
                    color=color
                )
                view.show()


        # steps=steps+1
        
    return pdb_file_list, fasta_file_list
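
# usage sketch for the pLM sampling routine above (hypothetical objects;
# x_data carries the conditioning: secondary-structure strings when
# tokenizer_X is given, or normalized force-path arrays when it is None):
#
#     pdb_files, fasta_files = sample_sequence_omegafold_pLM_ModelB (
#         model, x_data=[x_cond_0], foldproteins=True,
#         train_unet_number=1, max_length=max_length,
#         CKeys=CKeys, sample_dir='./samples/', steps=0, e=0,
#         pLM_Model=pLM_Model, pLM_Model_Name=pLM_Model_Name,
#         image_channels=image_channels, pLM_alphabet=esm_alphabet,
#     )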

#
def sample_sequence_omegafold_ModelB (
    model,
    X=None, #this is the target conventionally when using text embd
    flag=0,
    cond_scales=1.,
    foldproteins=False,
    X_string=None,
    x_data=None,  
    skip_steps=0,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,
    # ++++++++++++++++++++++++
    ynormfac=1,
    train_unet_number=1,
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    prefix=None,
    tokenizer_y=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only when foldproteins=True
):
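    """
    Sample sequences with Model B and optionally fold them with OmegaFold.
    Conditioning comes from one of: X (pre-tokenized text), X_string (raw
    strings, tokenized via tokenizer_X), or x_data (cond_img channel).
    Returns the path of the last properly named PDB file, or None when
    foldproteins=False.
    """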
    
    # steps=0
    # e=flag

    #num_samples = min (num_samples,y_train_batch.shape[0] )
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_string (from string)...")
    
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditioning x_data...")
        lenn_val=len(x_data)
        print (x_data)
        
    print ('Device: ', device)

    # + for debug
    print('tot ', lenn_val)
    for iisample in range (lenn_val):
        print("Working on ", iisample)
        X_cond=None  
        if X_string is None and X is not None: #only do if X provided
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None:
            XX = tokenizer_X.texts_to_sequences(X_string[iisample])
            XX= sequence.pad_sequences(XX,  maxlen=max_length, padding='post', truncating='post')  
            XX=np.array(XX)
            X_cond=torch.from_numpy(XX).float()/Xnormfac
            print ('Tokenized and processed: ', X_cond)
        
        print ("X_cond=", X_cond)
        
        # ++
        if tokenizer_X is not None:
            result=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=x_data[iisample], 
                # ++
                x_data_tokenized=None,
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,device=device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X,
                Xnormfac=Xnormfac,
                max_length=max_length,
            )
        else:
            x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac)
            x_data_tokenized=x_data_tokenized.to(torch.float)
            # + for debug:
            if CKeys['Debug_TrainerPack']==1:
                print("x_data_tokenized dim: ", x_data_tokenized.shape)
                print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                print("test: ", x_data_tokenized!=None)
            result=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=None, 
                # ++
                x_data_tokenized=x_data_tokenized,
                #
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,device=device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X,
                Xnormfac=Xnormfac,
                max_length=max_length,
            )
          
            
        result=torch.round(result*ynormfac)
        # + for debug
        print('result dim: ', result.shape)
        
        fig=plt.figure()
        plt.plot (
            result[0,0,:].cpu().detach().numpy(),
            label='Predicted'
        )
        #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
        plt.legend()
        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
        if IF_showfig==1:
            plt.show ()
        else:
            plt.savefig(outname, dpi=200)
        plt.close()
            

        to_rev=result[:,0,:] 
        to_rev=to_rev.long().cpu().detach().numpy()
        print (to_rev.shape)
        y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

        for iii in range (len(y_data_reversed)):
            y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
        
        ### reverse secondary-structure input...
        if X_cond is not None:
            X_cond=torch.round(X_cond*Xnormfac)

            to_rev=X_cond[:,:] 
            to_rev=to_rev.long().cpu().detach().numpy()
            print (to_rev.shape)
            X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)

            for iii in range (len(y_data_reversed)):
                X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
        if x_data is not None:
            X_data_reversed=x_data # already in sequence format
               

        # print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence", y_data_reversed[iisample])
        print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed)
        
        # + for debug
        print("================================================")
        print("foldproteins: ", foldproteins)
        
        if not foldproteins:
            pdb_file=None
            
        else:
            
            if X_cond is not None:
                xbc=X_cond[iisample,:].cpu().detach().numpy()
                out_nam=np.array2string(xbc, formatter={'float_kind':lambda v: "%.1f" % v})+f'_{flag}_{steps}'
            if x_data is not None:
                # ----------------------------------
                # x_data[iisample] can be too long for a file name
                # out_nam=x_data[iisample]
                # ++++++++++++++++++++++++++++++++++
                # use the sample index instead
                out_nam=iisample
             
            
            tempname='temp'
            pdb_file=foldandsavePDB (
                sequence=y_data_reversed[0], 
                filename_out=tempname, 
                num_cycle=num_cycle, 
                flag=flag,
                # +++++++++++++++++++
                # prefix=prefix,
                prefix=sample_dir,
            )

            # out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta'
            # ------------------------------------------
            # this one can be too long for a name
            # out_nam_fasta=f'{sample_dir}{out_nam}_{flag}_{e}_{iisample}.fasta'
            # ++++++++++++++++++++++++++++++++++++++++++
            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'

            write_fasta (y_data_reversed[0], out_nam_fasta)            
            
        
            # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
            # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
            # -------------------------------------------
            # this one can be too long for a name
            # However, the input X is recorded in the code
            # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb'
            # +++++++++++++++++++++++++++++++++++++++++++
            out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
            
            # print('Debug 1: out: ', out_nam)
            # print('Debug 2: in: ', pdb_file)
            shutil.copy (pdb_file, out_nam) #source, dest
            # cmd_line = 'cp ' + pdb_file + ' ' + out_nam
            # print(cmd_line)
            # os.popen(cmd_line)
            # print('Debug 3')
            pdb_file=out_nam           
            
            
            
            print (f"Properly named PDB file produced: {pdb_file}")
            if IF_showfig==1:
                #flag=1000
                view=show_pdb(
                    pdb_file=pdb_file, 
                    flag=flag,
                    show_sidechains=show_sidechains, 
                    show_mainchains=show_mainchains, 
                    color=color
                )
                view.show()


        # steps=steps+1
        
    return pdb_file 
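
# Example usage (a hypothetical sketch, kept commented out; the argument
# values are placeholders, not tested settings):
# pdb_file = sample_sequence_omegafold_ModelB (
#     model,
#     x_data=[one_cond_input],       # a single conditioning record
#     tokenizer_y=tokenizer_y,
#     ynormfac=ynormfac,
#     CKeys=CKeys,
#     sample_dir='./samples/',
#     steps=0, e=0,
#     foldproteins=True,
#     IF_showfig=0,
# )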

# ++ for de novo input of ForcPath
# ++
def sample_sequence_omegafold_pLM_ModelB_For_ForcPath (
    model,
    X=None, #this is the target conventionally when using text embd
    flag=0,
    cond_scales=[1.], # 1.,
    foldproteins=False,
    X_string=None,
    x_data=None,  
    skip_steps=0,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,
    # ++++++++++++++++++++++++
    ynormfac=1,
    train_unet_number=1,
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    prefix=None,
    tokenizer_y=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only when foldproteins=True
    # ++
    pLM_Model=None, # pLM_Model,
    pLM_Model_Name=None, # pLM_Model_Name,
    image_channels=None, # image_channels,
    pLM_alphabet=None, # esm_alphabet,
):
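    """
    De novo sampling for the ForcPath task (Model B + pLM decoding).
    x_data entries are padded to max_length and normalized by Xnormfac, then
    passed through the x_data_tokenized channel; one sampling pass is run per
    value in cond_scales. Returns (pdb_file_list, fasta_file_list).
    """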
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # prepare input in different channels
    # 
    X_cond=None  # this is for text-conditioning
    if X_string is None and X is not None: #only do if X provided
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
        # shape of X: [[..],[..]]: double bracket
        X_cond=torch.Tensor(X).to(device)
        # --
        # X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
    if X_string is not None:
        print (f"Producing {len(X_string)} samples...from text conditioning X_string (from string)...")
        lenn_val=len(X_string)
        # --
        # XX = tokenizer_X.texts_to_sequences(X_string[iisample])
        # ++
        XX = tokenizer_X.texts_to_sequences(X_string)
        XX= sequence.pad_sequences(XX,  maxlen=max_length, padding='post', truncating='post')  
        XX=np.array(XX)
        X_cond=torch.from_numpy(XX).float()/Xnormfac
        print ('Tokenized and processed: ', X_cond)
        
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditioning x_data...")
        lenn_val=len(x_data)
        if tokenizer_X is None: # for ForcPath,
            # need to do Padding and Normalization
            # and then put into tokenized data channel
            x_data_tokenized=[]
            for ii in range(lenn_val):
                x_data_one_line=pad_a_np_arr(x_data[ii], 0.0, max_length)
                x_data_tokenized.append(x_data_one_line)
            x_data_tokenized=np.array(x_data_tokenized)
            x_data_tokenized=torch.from_numpy(x_data_tokenized/Xnormfac)
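            # e.g., assuming pad_a_np_arr(arr, val, n) right-pads arr with val
            # up to length n (consistent with its use here):
            #   pad_a_np_arr([1.0, 2.0], 0.0, 5) -> [1.0, 2.0, 0.0, 0.0, 0.0]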
        else:
            # leave for SecStr case: TBA
            pass
        # print (x_data)
        # ++ for result_mask based on input: x_data or x_data_tokenized
        # ++: for model B: using mask from the input
        # extract the mask/seq_len from input if possible
        if tokenizer_X is not None:
            # for SecStr+ModelB
            result_mask = read_mask_from_input(
                tokenized_data=None, 
                mask_value=None,
                seq_data=x_data, # x_data[iisample],
                max_seq_length=max_length,
            )
        else:
            # for ForcPath+ModelB
            result_mask = read_mask_from_input(
                tokenized_data=x_data_tokenized, # None, 
                mask_value=0, # None,
                seq_data=None, # x_data[iisample],
                max_seq_length=None, # max_length,
            )
            
        
    print ("Input contents:")    
    print ("cond_img condition: x_data=\n", x_data)
    print ("Text condition: X_cond=\n", X_cond)
    
    # store the results
    pdb_file_list=[]
    fasta_file_list=[]
    
    # loop over cond_scales
    for idx_cond, this_cond_scale in enumerate(cond_scales):
        print(f"Working on cond_scale {str(this_cond_scale)}")
        # do sampling
        # -----------------------------------------------------------------
        # for below, two branches are all for cond_img, not for text_cond
        if tokenizer_X is not None:
            # for SecStr+ModelB, not tested here
            result_embedding=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=this_cond_scale, # cond_scales ,
                x_data=x_data, # x_data[iisample],  # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim
                # ++
                x_data_tokenized=None,
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,device=device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X,
                Xnormfac=Xnormfac,
                max_length=max_length,
            )
        else:
            # for ForcPath+ModelB:
            # for model.sample() here using x_data_tokenized channel
            x_data_tokenized=x_data_tokenized.to(torch.float) # shape [batch, max_seq_len]
            # leave channel expansion for the self.sample() to handle
            
            # + for debug:
            if CKeys['Debug_TrainerPack']==3:
                print("x_data_tokenized dim: ", x_data_tokenized.shape)
                print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                print("test x_data_tokenized!=None: ", x_data_tokenized!=None)
            
            result_embedding=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=this_cond_scale, # cond_scales ,
                x_data=None, 
                # ++
                x_data_tokenized=x_data_tokenized,
                #
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,device=device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X,
                Xnormfac=Xnormfac,
                max_length=max_length,
            )
            
        # handle the results: from embedding into AA  
        # ++ for pLM
        if image_channels==33:
            # pass
            result_tokens,result_logits = convert_into_tokens_using_prob(
                result_embedding,
                pLM_Model_Name,
            )
        else:
            # full record
            # result_embedding as image.dim: [batch, channels, seq_len]
            # result_tokens.dim: [batch, seq_len]
            result_tokens,result_logits = convert_into_tokens(
                pLM_Model, 
                result_embedding,
                pLM_Model_Name,
            )
        # +++++++++++++++++++++++++++++++++
        result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]
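        # (assumption: convert_into_tokens/convert_into_tokens_using_prob
        #  recover per-position token ids, e.g. via argmax over the pLM
        #  logits, mapping [batch, channels, seq_len] -> [batch, seq_len])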
        
        # + for debug
        print('result dim: ', result.shape)
        
        # plot sequence token code: esm (33 tokens), for one batch
        fig=plt.figure()
        for ii in range(lenn_val):
            plt.plot (
                result[ii,0,:].cpu().detach().numpy(),
                label= f'Predicted for Input#{str(ii)}'
            )
        #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
        plt.legend()
        outname = sample_dir+ f"DenovoInputXs_CondScale_No{str(idx_cond)}_Val_{str(this_cond_scale)}_{e}_{steps}.jpg"
        #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
        if IF_showfig==1:
            plt.show ()
        else:
            plt.savefig(outname, dpi=200)
        plt.close()
        
        # translate result into AA
        to_rev=result[:,0,:] # token (batch,seq_len)
        if CKeys['Debug_TrainerPack']==3:
            print("on foldable result: ", to_rev[0])
            print("on result_logits: ", result_logits[0])
            print("on mask: ", result_mask[0])
            a = decode_one_ems_token_rec_for_folding_with_mask(
                to_rev[0],
                result_logits[0],
                pLM_alphabet,
                pLM_Model,
                result_mask[0],
            )
            print('One resu: ', a)

        y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask(
            to_rev,
            result_logits,
            pLM_alphabet,
            pLM_Model,
            result_mask,
        )
        if CKeys['Debug_TrainerPack']==3:
            print("on y_data_reversed[0]: ", y_data_reversed[0])
            
        ### reverse secondary-structure input...
        if X_cond is not None:
            X_cond=torch.round(X_cond*Xnormfac)

            to_rev=X_cond[:,:] 
            to_rev=to_rev.long().cpu().detach().numpy()
            print (to_rev.shape)
            X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)

            for iii in range (len(y_data_reversed)):
                X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
        if x_data is not None:
            # works for secondary-structure input
            # works for ForcPath input
            X_data_reversed=x_data # already in sequence format
        
        # sections for each one result
        for iisample in range(lenn_val):
            print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample])
            
            out_nam_fasta=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.fasta'
            write_fasta (y_data_reversed[iisample], out_nam_fasta) 
            fasta_file_list.append(out_nam_fasta)
        
            # + for debug
            print("================================================")
            print("foldproteins: ", foldproteins)
            
            if not foldproteins:
                pdb_file=None

            else:

                if X_cond is not None:
                    # not maintained
                    xbc=X_cond[iisample,:].cpu().detach().numpy()
                    out_nam=np.array2string(xbc, formatter={'float_kind':lambda v: "%.4f" % v})+f'_{flag}_{steps}'
                if x_data is not None:
                    pass
                    # #xbc=x_data[iisample] 
                    # # ----------------------------------
                    # # this one can be too long for a name
                    # out_nam=x_data[iisample] 
                    # # ++++++++++++++++++++++++++++++++++
                    # # 
                    # out_nam=iisample

                tempname='temp'
                pdb_file, fasta_file=foldandsavePDB_pdb_fasta (
                    sequence=y_data_reversed[iisample], 
                    filename_out=tempname, 
                    num_cycle=num_cycle, 
                    flag=flag,
                    # +++++++++++++++++++
                    # prefix=prefix,
                    prefix=sample_dir,
                )         


                # out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
                # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{steps}.pdb'
                # -------------------------------------------
                # this one can be too long for a name
                # However, the input X is recorded in the code
                # out_nam=f'{sample_dir}{X_data_reversed[iisample]}_{flag}_{iisample}.pdb'
                # +++++++++++++++++++++++++++++++++++++++++++
                out_nam=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.pdb'
                # out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'

                # print('Debug 1: out: ', out_nam)
                # print('Debug 2: in: ', pdb_file)
                shutil.copy (pdb_file, out_nam) #source, dest
                # shutil.copy (fasta_file, out_nam_fasta)
                # cmd_line = 'cp ' + pdb_file + ' ' + out_nam
                # print(cmd_line)
                # os.popen(cmd_line)
                # print('Debug 3')
                # clean the slate to avoid mistakenly reusing the previous fasta file
                os.remove (pdb_file)
                os.remove (fasta_file)

                pdb_file=out_nam
                # fasta_file=out_nam_fasta
                pdb_file_list.append(pdb_file)
                
                # ++ write the input condition as a reference: for ForcPath
                out_nam_inX=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}_input.txt'
                if torch.is_tensor(X_data_reversed[iisample]):
                    # for safety, not used usually
                    xbc=X_data_reversed[iisample].cpu().detach().numpy()
                else:
                    xbc=X_data_reversed[iisample]
                if tokenizer_X is None:
                    # for ForcPath case
                    out_inX=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})
                else:
                    # for SecStr case
                    out_inX=xbc
                with open(out_nam_inX, "w") as inX_file:
                    inX_file.write(out_inX)


                print (f"Properly named PDB file produced: {pdb_file}")
                if IF_showfig==1:
                    #flag=1000
                    view=show_pdb(
                        pdb_file=pdb_file, 
                        flag=flag,
                        show_sidechains=show_sidechains, 
                        show_mainchains=show_mainchains, 
                        color=color
                    )
                    view.show()            
    
        
    return pdb_file_list, fasta_file_list

# ++
def sample_sequence_pLM_ModelB_For_ForcPath_Predictor (
    model,
    X=None, #this is the target conventionally when using text embd
    flag=0,
    cond_scales=[1.], # 1.,
    foldproteins=False,
    X_string=None,
    x_data=None,  
    skip_steps=0,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,
    # ++++++++++++++++++++++++
    ynormfac=1,
    train_unet_number=1,
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    prefix=None,
    tokenizer_y=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only when foldproteins=True
    # ++
    pLM_Model=None, # pLM_Model,
    pLM_Model_Name=None, # pLM_Model_Name,
    image_channels=None, # image_channels,
    pLM_alphabet=None, # esm_alphabet,
    # ++
    esm_layer=None,
):
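    """
    Predictor direction: given a list of AA sequences as strings (x_data),
    predict the ForcPath signal. Returns (resu_prediction, seq_len_list),
    where resu_prediction[str(cond_scale)] is a list of per-sequence tensors
    trimmed to seq_len+1 entries.
    """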
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # input: a list of AA sequence in string format
    # output: ForcPath prediction
    #
    # 1. decide input channel: text_cond or img_cond
    X_cond=None # this is for text-conditioning
    if X_string is None and X is not None: #only do if X provided
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
        # shape of X: [[..],[..]]: double bracket
        X_cond=torch.Tensor(X).to(device)
        # --
        # X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
    if X_string is not None:
        print (f"Producing {len(X_string)} samples...from text conditioning X_string (from string)...")
        lenn_val=len(X_string)
        # --
        # XX = tokenizer_X.texts_to_sequences(X_string[iisample])
        # ++
        XX = tokenizer_X.texts_to_sequences(X_string)
        XX= sequence.pad_sequences(XX,  maxlen=max_length, padding='post', truncating='post')  
        XX=np.array(XX)
        X_cond=torch.from_numpy(XX).float()/Xnormfac
        print ('Tokenized and processed: ', X_cond)
        
        
    if x_data is not None: # this is for the img_conditioning channel
        # this is the channel actually used here
        print (f"Producing {len(x_data)} samples...from image conditioning x_data...")
        lenn_val=len(x_data)
        seq_len_list=[]
        for this_AA in x_data:
            seq_len_list.append(len(this_AA))
            
        
    print ("Input contents:")    
    print ("cond_img condition: x_data=\n", x_data)
    print ("Text condition: X_cond=\n", X_cond)
    
    # 2. perform sampling
    # loop over cond_scales
    resu_prediction={}
    
    for idx_cond, this_cond_scale in enumerate(cond_scales):
        print(f"Working on cond_scale {str(this_cond_scale)}")
        # leave the translation from seq to tokenized 
        # in the model.sample function using x_data channel
        # Need to pass on the esm model part or tokenizer_X
        # 
        result_embedding=model.sample ( 
            x=X_cond,
            stop_at_unet_number=train_unet_number ,
            cond_scale=this_cond_scale, # cond_scales ,
            x_data=x_data, # x_data[iisample],  # will pass through tokenizer_X in this sample(), channels will be matched with self.pred_dim
            # ++
            x_data_tokenized=None,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,device=device,
            # ++++++++++++++++++++++++++
            tokenizer_X=tokenizer_X,
            Xnormfac=Xnormfac,
            max_length=max_length,
            # ++ for esm
            pLM_Model=pLM_Model,
            pLM_alphabet=pLM_alphabet,
            esm_layer=esm_layer,
            pLM_Model_Name=pLM_Model_Name,
            # image_channels=image_channels,
        )
        
        # convert into prediction
        # 3. translate prediction into something meaningful
        #    consider channel average and masking
        # average across channels
        result_embedding=torch.mean(result_embedding, 1) # (batch, seq_len)
        # read mask from the input sequences x_data -> (batch, seq_len)
        # result_mask looks like 0,1,1,...,1,0,0
        # (the 0th component is filled with zero)
        result_mask = read_mask_from_input(
            tokenized_data=None, # X_train_batch[:num_samples], 
            mask_value=0.0,
            seq_data=x_data, # None,
            max_seq_length=max_length, # None,
        )
        # apply mask to result: keep True positions and zero out all False ones
        # note: as a side effect, this also zeroes the 0th component
        result = result_embedding.cpu()*result_mask # (batch, seq_len)
        # result = result.cpu()
        y_data_reversed = result*ynormfac
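        # e.g., one row: mean embedding [0.2, 0.5, 0.4, 0.3] with mask
        # [0, 1, 1, 0] -> [0.0, 0.5, 0.4, 0.0], then rescaled by ynormfac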
        # 4. translate the results into a list
        prediction_list = []
        for ii in range(len(x_data)):
            prediction_list.append(
                y_data_reversed[ii, :seq_len_list[ii]+1]
            )
        if CKeys['Debug_TrainerPack']==3:
            print("check prediction dim:")
            print(f"model output: ", y_data_reversed[0])
            print(f"prediction output: ", prediction_list[0])
        # 
        # store the results
        resu_prediction[str(this_cond_scale)]=prediction_list
    
    return resu_prediction,seq_len_list
       
        
        

        
    
    
    
    

# ++
# for ProteinDesigner
def sample_loop_omegafold_pLM_ModelB (
    model,
    train_loader,
    cond_scales=[7.5], #list of cond scales - each sampled...
    num_samples=2, #how many samples produced every time tested.....
    timesteps=100, # not used
    flag=0,
    foldproteins=False,
    use_text_embedd=True,
    skip_steps=0,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=1,
    prefix=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only when foldproteins=True
    # ++
    pLM_Model=None,
    pLM_Model_Name=None,
    image_channels=None,
    pLM_alphabet=None,
):
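    """
    Validation-time sampling loop over a dataloader: for every mini-batch,
    draw num_samples samples per cond_scale, decode the predicted tokens with
    the pLM alphabet, compare against the ground truth (recovery ratio), write
    .txt/.fasta records, and optionally fold the prediction (foldproteins).
    """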
    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    # steps=0
    # e=flag
    # for item  in train_loader:
    for idx, item  in enumerate(train_loader):

        X_train_batch= item[0].to(device)
        y_train_batch=item[1].to(device)
        
        # --
        # # ++ for pLM case:
        # if pLM_Model_Name=='None':
        #     # just use the encoded sequence
        #     # y_train_batch_in = y_train_batch.unsqueeze(1)
        #     X_train_batch_in = X_train_batch.unsqueeze(1)
        #     # pass
        # elif pLM_Model_Name=='esm2_t33_650M_UR50D':
        #     # with torch.no_grad():
        #     #     results = pLM_Model(
        #     #         y_train_batch,
        #     #         repr_layers=[33],
        #     #         return_contacts=False,
        #     #     )
        #     # y_train_batch_in = results["representations"][33]
        #     # y_train_batch_in = rearrange(
        #     #     y_train_batch_in, 
        #     #     'b l c -> b c l'
        #     # )
        #     X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1)


        GT=y_train_batch.cpu().detach() 

        GT= GT.unsqueeze(1)
        if num_samples>y_train_batch.shape[0]:
            print("Warning: sampling # > len(mini_batch)")

        num_samples = min (num_samples,y_train_batch.shape[0] )
        print (f"Producing {num_samples} samples...")
        X_train_batch_picked = X_train_batch[:num_samples,:] # X_train_batch_in[:num_samples ] # 
        print ('After pLM, (TEST) X_batch shape: ', X_train_batch_picked.shape)
                        
        for iisample in range (len (cond_scales)):

            if use_text_embedd:
                result_embedding=model.sample (
                    # x= X_train_batch,
                    x= X_train_batch_picked,
                    stop_at_unet_number=train_unet_number ,
                    cond_scale=cond_scales[iisample], 
                    device=device, 
                    skip_steps=skip_steps
                )
            else:
                result_embedding=model.sample (
                    x= None, 
                    # x_data_tokenized= X_train_batch,
                    x_data_tokenized= X_train_batch_picked, # dim=(batch, seq_len), will extend channels inside .sample(),
                    stop_at_unet_number=train_unet_number ,
                    cond_scale=cond_scales[iisample],
                    device=device,
                    skip_steps=skip_steps
                )
            # ++ for pLM:
            if image_channels==33:
                result_tokens,result_logits = convert_into_tokens_using_prob(
                    result_embedding,
                    pLM_Model_Name,
                )
            else:
                # full record
                # result_embedding as image.dim: [batch, channels, seq_len]
                # result_tokens.dim: [batch, seq_len]
                result_tokens,result_logits = convert_into_tokens(
                    pLM_Model, 
                    result_embedding,
                    pLM_Model_Name,
                )

            # # ---------------------------------
            # result=torch.round(result*ynormfac)
            # GT=torch.round (GT*ynormfac)
            # +++++++++++++++++++++++++++++++++
            result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]
            
            # +
            # # -------------------------------------------
            # #reverse y sequence
            # to_rev=result[:,0,:]
            # to_rev=to_rev.long().cpu().detach().numpy()
            # 
            # y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
            # 
            # for iii in range (len(y_data_reversed)):
            #     y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
            # ++++++++++++++++++++++++++++++++++++++++++++
            # extract the mask/seq_len from input if possible
            # here, from dataloader, we only use tokenized_data for mask generation
            result_mask = read_mask_from_input(
                tokenized_data=X_train_batch[:num_samples], 
                mask_value=0,
                seq_data=None,
                max_seq_length=None,
            )
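            # e.g., a tokenized row [5., 7., 0., 0.] with mask_value=0
            # presumably yields mask [1, 1, 0, 0] (padding positions excluded)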
            to_rev=result[:,0,:] # token (batch,seq_len)
            if CKeys['Debug_TrainerPack']==3:
                print("on foldable result: ", to_rev[0])
                print("on result_logits: ", result_logits[0])
                print("on mask: ", result_mask[0])
                a = decode_one_ems_token_rec_for_folding_with_mask(
                    to_rev[0],
                    result_logits[0],
                    pLM_alphabet,
                    pLM_Model,
                    result_mask[0],
                )
                print('One resu: ', a)

            y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask(
                to_rev,
                result_logits,
                pLM_alphabet,
                pLM_Model,
                result_mask,
            )
            if CKeys['Debug_TrainerPack']==3:
                print("on y_data_reversed[0]: ", y_data_reversed[0])

            # # ++++++++++++++++++++++++++++++++++++++++++++
            # # reverse the PREDICTED y into a foldable sequence
            # # save this block for Model A 
            # to_rev=result[:,0,:] # token (batch,seq_len)
            # y_data_reversed=decode_many_ems_token_rec_for_folding(
            #     to_rev,
            #     result_logits,
            #     pLM_alphabet,
            #     pLM_Model,
            # )
            # if CKeys['Debug_TrainerPack']==3:
            #     print("on foldable result: ", to_rev[0])
            #     print("on result_logits: ", result_logits[0])
            #     a = decode_one_ems_token_rec_for_folding(
            #         to_rev[0],
            #         result_logits[0],
            #         pLM_alphabet,
            #         pLM_Model,
            #     )
            #     print('One resu: ', a)
            #     print("on y_data_reversed: ", y_data_reversed[0])


            # # -----------------------------------------------------
            # #reverse GT_y sequence
            # to_rev=GT[:,0,:]
            # to_rev=to_rev.long().cpu().detach().numpy()
            # 
            # GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)
            # 
            # for iii in range (len(y_data_reversed)):
            #     GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
            #reverse GT_y sequence
            # GT should be SAFE to reverse
            to_rev=GT[:,0,:] # (batch,1,seq_len)->(batch, seq_len)
            GT_y_data_reversed=decode_many_ems_token_rec(
                to_rev,
                pLM_alphabet,
            )


            # -- not for SecStr anymore
            # ### reverse second structure input....
            # to_rev=torch.round (X_train_batch[:,:]*Xnormfac)
            # to_rev=to_rev.long().cpu().detach().numpy()
            # ++ 
            ### reverse general float input...
            to_rev=X_train_batch[:,:]*Xnormfac
            to_rev=to_rev.cpu().detach().numpy()
            # here, assume X_train_batch is the cond_img: there is padding at both the beginning and the end
            # so, first move the 0th (leading) padding element to the end:
            # Note:
            # 1. this is fine for the SecStr case (not maintained here)
            # 2. this is not ideal for ForcPath, but can be handled in MD postprocessing since the first component is always 0
            n_batch=to_rev.shape[0]
            n_embed=to_rev.shape[1]
            to_rev_1 = np.zeros(to_rev.shape)
            to_rev_1[:,0:n_embed-1]=to_rev[:,1:n_embed]
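            # e.g., per row: [0, f1, f2, ..., fk, 0] -> [f1, f2, ..., fk, 0, 0]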

            # ++ different input
            if tokenizer_X is not None:
                # change into int
                to_rev_1 = np.round(to_rev_1)
                X_data_reversed=tokenizer_X.sequences_to_texts (to_rev_1)
                for iii in range (len(y_data_reversed)):
                    X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
            else:
                X_data_reversed=to_rev_1.copy()
            # + for debug
            if CKeys['Debug_TrainerPack']==1:
                print("X_data_reversed: ", X_data_reversed)
                    

            for samples in range  (num_samples):
                print ("sample ", samples+1, "out of ", num_samples)

                fig=plt.figure()
                plt.plot (
                    result[samples,0,:].cpu().detach().numpy(),
                    label='Predicted'
                )
                plt.plot (
                    GT[samples,0,:],
                    label='GT 0'
                )
                plt.legend()
                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                if IF_showfig==1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                plt.close ()
                
                

                # print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples])
                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} \nor\n {X_data_reversed[samples]}, ")
                print (f"predicted sequence: {y_data_reversed[samples]}")
                print (f"Ground truth:       {GT_y_data_reversed[samples]}")
                error=string_diff (y_data_reversed[samples], GT_y_data_reversed[samples])/len (GT_y_data_reversed[samples])
                print(f"Recovery ratio (Ref): {1.-error}")
                
                # move some
                # # -- X_train_batch is normalized
                # xbc=X_train_batch[samples,:].cpu().detach().numpy()
                # # out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                # out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                # ++
                xbc = X_data_reversed[samples]
                if isinstance(xbc, str):
                    out_nam_content=xbc
                else:
                    out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})
                # 1. write out the input X in the dataloader
                out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt'
                # + write the condition clearly
                # X_data_reversed: an array
                with open(out_nam_inX, "w") as inX_file:
                    # inX_file.write(f'{X_data_reversed[samples]}\n')
                    inX_file.write(out_nam_content)
                # 2. write out the predictions
                out_nam_OuY_PR=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}_predict.fasta'
                with open(out_nam_OuY_PR, "w") as ouY_fasta:
                    ouY_fasta.write(f">Predicted\n")
                    ouY_fasta.write(y_data_reversed[samples])
                # 3. Only for dataloader: write out the recovered ground truth
                out_nam_OuY_GT=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}_recGT.fasta'
                with open(out_nam_OuY_GT, "w") as ouY_fasta:
                    ouY_fasta.write(f">reconstructed GT, recoverabliblity: {1.-error}\n")
                    ouY_fasta.write(GT_y_data_reversed[samples])
                
                

                if foldproteins:
                    
                    tempname='temp'
                    pdb_file,fasta_file=foldandsavePDB_pdb_fasta (
                        sequence=y_data_reversed[samples], 
                        filename_out=tempname, 
                        num_cycle=16, flag=flag,
                        # +++++++++++++++++++
                        prefix=prefix
                    )

                    # #out_nam=f'{prefix}{out_nam}.pdb'
                    # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
                    # ------------------------------------------------------
                    # sometimes, the name below can get too long to fit
                    # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # add a way to save the sampling name and results
                    # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                    out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
                    out_nam_seq=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.fasta'
                    
                    
                    if CKeys['Debug_TrainerPack']==1:
                        print("pdb_file: ", pdb_file)
                        print("out_nam: ", out_nam)
                        
                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                    shutil.copy (pdb_file, out_nam) #source, dest
                    shutil.copy (fasta_file, out_nam_seq)
                    
                   
                    # clean the slate to avoid mistakenly using the previous fasta file
                    os.remove (pdb_file)
                    os.remove (fasta_file)
                    
                    
                    pdb_file=out_nam
                    print (f"Properly named PDB file produced: {pdb_file}")
                    print (f"input X for sampling stored: {pdb_file}")
                    
                    if IF_showfig==1:
                        view=show_pdb(
                            pdb_file=pdb_file, 
                            flag=flag, 
                            show_sidechains=show_sidechains,  
                            show_mainchains=show_mainchains, 
                            color=color
                        )
                        view.show()
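
# ++ a minimal sketch (hypothetical helper, not called elsewhere) of the
# recovery ratio printed above; assumes string_diff counts mismatched
# positions between two equal-length strings:
def recovery_ratio_sketch(pred_seq, gt_seq):
    # fraction of recovered positions; 1.0 means a perfect match
    return 1.0 - string_diff(pred_seq, gt_seq) / len(gt_seq)
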
# ++
# For ProteinPredictor
def sample_loop_omegafold_pLM_ModelB_Predictor (
    model,
    train_loader,
    cond_scales=[7.5], #list of cond scales - each sampled...
    num_samples=2, #how many samples produced every time tested.....
    timesteps=100, # not used
    flag=0,
    foldproteins=False,
    use_text_embedd=True,
    skip_steps=0,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=1,
    prefix=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only after foldproteins=True
    # ++
    pLM_Model=None,
    pLM_Model_Name=None,
    image_channels=None,
    pLM_alphabet=None,
    # ++
    esm_layer=None,
):
    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    # steps=0
    # e=flag
    # for item  in train_loader:
    val_epoch_MSE_list=[]
    resu_pred = {}
    resu_grou = {}
    # 
    for iisample in range (len (cond_scales)):
        # calculate loss for one selected cond_scales
        # 
        val_epoch_MSE=0.
        num_rec=0
        this_prediction = []
        this_groundtruth = []
        # 
        for idx, item  in enumerate(train_loader):

            X_train_batch= item[0].to(device)
            y_train_batch=item[1].to(device)

            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            # 1. adjust the number of sample to collect in each batch
            if num_samples>y_train_batch.shape[0]:
                print("Warning: sampling # > len(mini_batch)")
            num_samples = min (num_samples,y_train_batch.shape[0])
            print (f"Producing {num_samples} samples...")
            X_train_batch_picked = X_train_batch[:num_samples,:] # X_train_batch_in[:num_samples ] # 
            GT=y_train_batch.cpu().detach()
            GT_picked = GT[:num_samples,:]
            # GT_picked = GT_picked.unsqueeze(1)
            
            # 2. prepare if pLM is used at the input end:
            #    this is done inside model.sample() via the x_data_tokenized channel
            #
            # 3. sample inside the loop of cond_scales
            if use_text_embedd:
                result_embedding=model.sample (
                    # x= X_train_batch,
                    x= X_train_batch_picked,
                    stop_at_unet_number=train_unet_number ,
                    cond_scale=cond_scales[iisample], 
                    device=device, 
                    skip_steps=skip_steps,
                    # ++
                    pLM_Model_Name=pLM_Model_Name,
                    pLM_Model=pLM_Model,
                    pLM_alphabet=pLM_alphabet,
                    esm_layer=esm_layer,
                )
            else:
                result_embedding=model.sample (
                    x= None, 
                    # x_data_tokenized= X_train_batch,
                    x_data_tokenized= X_train_batch_picked, # dim=(batch, seq_len), will extend channels inside .sample(),
                    stop_at_unet_number=train_unet_number ,
                    cond_scale=cond_scales[iisample],
                    device=device,
                    skip_steps=skip_steps,
                    # ++
                    pLM_Model_Name=pLM_Model_Name,
                    pLM_Model=pLM_Model,
                    pLM_alphabet=pLM_alphabet,
                    esm_layer=esm_layer,
                )
            # result_embedding as image; dim: [batch, channels, seq_len]
            #
            # 4. translate the prediction into something meaningful:
            #    channel-average, then apply the input mask
            # average across channels
            result_embedding=torch.mean(result_embedding, 1) # (batch, seq_len)
            # read mask from input: X_train_batch_picked (batch, seq_len)
            # result_mask looks like: 0,1,1,...,1,0,0
            # the 0th component will be set to zero
            result_mask = read_mask_from_input(
                tokenized_data=X_train_batch[:num_samples], 
                mask_value=0.0,
                seq_data=None,
                max_seq_length=None,
            )
            # apply mask to result: keep true and zero all false
            result = result_embedding*result_mask # (batch, seq_len)
            result = result.cpu()
            # result = result.unsqueeze(1) # (batch, 1, seq_len)
            # this is ONLY the result from the model, not the prediction yet
            # 
            # 5. calculate loss
            with torch.no_grad():
                val_loss_MSE = criterion_MSE_sum(
                    result,
                    GT_picked,
                )
            val_epoch_MSE += val_loss_MSE.item()/GT_picked.shape[1]
            num_rec += len(GT_picked)
            # 
            # 6. convert into prediction
            y_data_reversed = result*ynormfac
            # prepare GT
            GT_y_data_reversed = GT_picked*ynormfac
            # accumulate the results
            for ibat in range (GT_picked.shape[0]):
                this_prediction.append (np.array( y_data_reversed[ibat,:].cpu() ))
                this_groundtruth.append (np.array( GT_y_data_reversed[ibat,:].cpu() ))
            
            # 
            # 7. reverse input to an AA sequence... if needed
            # TBA
        
        # for one cond_scale:
        # summarize the loss
        TestSet_MSE = val_epoch_MSE/num_rec
        resu_pred[str(cond_scales[iisample])] = this_prediction
        resu_grou[str(cond_scales[iisample])] = this_groundtruth
        # store the MSE for this cond_scale
        # (append inside the loop, so one entry per cond_scale is recorded)
        val_epoch_MSE_list.append(TestSet_MSE)
    
    return val_epoch_MSE_list, resu_pred, resu_grou
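
# ++ a hedged usage sketch for the predictor loop above; `model`,
# `test_loader` and the pLM objects are assumed to be built elsewhere in
# the pipeline, so this stays commented:
# mse_per_scale, resu_pred, resu_grou = sample_loop_omegafold_pLM_ModelB_Predictor(
#     model, test_loader,
#     cond_scales=[1.0], num_samples=4,
#     use_text_embedd=False,
#     train_unet_number=1, ynormfac=ynormfac,
#     CKeys=CKeys, sample_dir=sample_dir, steps=steps, e=e,
#     pLM_Model=pLM_Model, pLM_Model_Name=pLM_Model_Name,
#     pLM_alphabet=esm_alphabet, esm_layer=esm_layer,
#     image_channels=image_channels,
# )
# print("test-set MSE per cond_scale:", mse_per_scale)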
        

                        
# ++++++++++++++++++++++++++++++++++++++++++++++++
def sample_loop_omegafold_ModelB (
    model,
    train_loader,
    cond_scales=[7.5], #list of cond scales - each sampled...
    num_samples=2, #how many samples produced every time tested.....
    timesteps=100, # not used
    flag=0,
    foldproteins=False,
    use_text_embedd=True,
    skip_steps=0,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=1,
    prefix=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # effective only after foldproteins=True
):
    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    # steps=0
    # e=flag
    # for item  in train_loader:
    for idx, item  in enumerate(train_loader):

        X_train_batch= item[0].to(device)
        y_train_batch=item[1].to(device)

        GT=y_train_batch.cpu().detach() 

        GT= GT.unsqueeze(1)
        if num_samples>y_train_batch.shape[0]:
            print("Warning: sampling # > len(mini_batch)")

        num_samples = min (num_samples,y_train_batch.shape[0] )
        print (f"Producing {num_samples} samples...")
        X_train_batch_picked = X_train_batch[:num_samples,:]
        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)

        for iisample in range (len (cond_scales)):

            if use_text_embedd:
                result=model.sample (
                    # x= X_train_batch,
                    x= X_train_batch_picked,
                    stop_at_unet_number=train_unet_number ,
                    cond_scale=cond_scales[iisample], 
                    device=device, 
                    skip_steps=skip_steps
                )
            else:
                result=model.sample (
                    x= None, 
                    # x_data_tokenized= X_train_batch,
                    x_data_tokenized= X_train_batch_picked,
                    stop_at_unet_number=train_unet_number ,
                    cond_scale=cond_scales[iisample],
                    device=device,
                    skip_steps=skip_steps
                )
        
            result=torch.round(result*ynormfac)
            GT=torch.round (GT*ynormfac)

            for samples in range  (num_samples):
                print ("sample ", samples+1, "out of ", num_samples)

                fig=plt.figure()
                plt.plot (
                    result[samples,0,:].cpu().detach().numpy(),
                    label= f'Predicted'
                )
                plt.plot (
                    GT[samples,0,:],
                    label= f'GT {0}'
                )
                plt.legend()
                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                if IF_showfig==1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                plt.close ()

                #reverse y sequence
                to_rev=result[:,0,:]
                to_rev=to_rev.long().cpu().detach().numpy()

                y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                for iii in range (len(y_data_reversed)):
                    y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")

                #reverse GT_y sequence
                to_rev=GT[:,0,:]
                to_rev=to_rev.long().cpu().detach().numpy()

                GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                for iii in range (len(GT_y_data_reversed)):
                    GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")

                ### reverse secondary structure input....
                to_rev=torch.round (X_train_batch[:,:]*Xnormfac)
                to_rev=to_rev.long().cpu().detach().numpy()
                
                # ++ different input
                if tokenizer_X is not None:
                    X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
                    for iii in range (len(X_data_reversed)):
                        X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
                else:
                    X_data_reversed=to_rev.copy()
                # + for debug
                if CKeys['Debug_TrainerPack']==1:
                    print("X_data_reversed: ", X_data_reversed)
                

                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples])
                print (f"Ground truth: {GT_y_data_reversed[samples]}")

                if foldproteins:
                    # -- dead code: xbc/out_nam were unused (out_nam is rebuilt below)
                    # xbc=X_train_batch[samples,:].cpu().detach().numpy()
                    # out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                    tempname='temp'
                    pdb_file=foldandsavePDB (
                        sequence=y_data_reversed[samples], 
                        filename_out=tempname, 
                        num_cycle=16, flag=flag,
                        # +++++++++++++++++++
                        prefix=prefix
                    )

                    # #out_nam=f'{prefix}{out_nam}.pdb'
                    # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
                    # ------------------------------------------------------
                    # sometimes, the name below can get too long to fit
                    # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # add a way to save the sampling name and results
                    # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                    out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
                    out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt'
                    
                    if CKeys['Debug_TrainerPack']==1:
                        print("pdb_file: ", pdb_file)
                        print("out_nam: ", out_nam)
                        
                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                    shutil.copy (pdb_file, out_nam) #source, dest
                    # +
                    with open(out_nam_inX, "w") as inX_file:
                        inX_file.write(f'{X_data_reversed[samples]}\n')
                        
                    pdb_file=out_nam
                    print (f"Properly named PDB file produced: {pdb_file}")
                    print (f"input X for sampling stored: {pdb_file}")
                    
                    if IF_showfig==1:
                        view=show_pdb(
                            pdb_file=pdb_file, 
                            flag=flag, 
                            show_sidechains=show_sidechains,  
                            show_mainchains=show_mainchains, 
                            color=color
                        )
                        view.show()

#                 steps=steps+1
                
#         if steps>num_samples:
#             break

# 
#
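# ++ a minimal sketch (hypothetical helper; the loops above inline this) of
# the token-to-text round trip, assuming a keras-style tokenizer whose
# sequences_to_texts returns space-separated, lower-case tokens:
def tokens_to_clean_texts_sketch(token_mat, tokenizer):
    # token_mat: (batch, seq_len) integer token ids
    texts = tokenizer.sequences_to_texts(token_mat)
    return [t.upper().strip().replace(" ", "") for t in texts]
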
def sample_loop_FromModelB (
    model,
    train_loader,
    cond_scales=[7.5], #list of cond scales - each sampled...
    num_samples=2, #how many samples produced every time tested.....
    timesteps=100,
    flag=0,
    foldproteins=False,
    use_text_embedd=True,
    skip_steps=0,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=1,
    prefix=None,
    tokenizer_y=None,
    Xnormfac=1,
    tokenizer_X=None,
):
    steps=0
    e=flag
    for item  in train_loader:

            X_train_batch= item[0].to(device)
            y_train_batch=item[1].to(device)

            GT=y_train_batch.cpu().detach() 
                    
            GT= GT.unsqueeze(1)
            num_samples = min (num_samples,y_train_batch.shape[0] )
            print (f"Producing {num_samples} samples...")
            
            print ('X_train_batch shape: ', X_train_batch.shape)

            for iisample in range (len (cond_scales)):
                
                if use_text_embedd:
                    result=model.sample (x= X_train_batch,stop_at_unet_number=train_unet_number ,
                                         cond_scale=cond_scales[iisample], device=device, skip_steps=skip_steps)
                else:
                    result=model.sample (x= None, x_data_tokenized= X_train_batch,
                                         stop_at_unet_number=train_unet_number ,
                                         cond_scale=cond_scales[iisample],device=device,skip_steps=skip_steps)
                    
                result=torch.round(result*ynormfac)
                GT=torch.round (GT*ynormfac)

                for samples in range  (num_samples):
                    print ("sample ", samples, "out of ", num_samples)
                    
                    plt.plot (result[samples,0,:].cpu().detach().numpy(),label= f'Predicted')
                    plt.plot (GT[samples,0,:],label= f'GT {0}')
                    plt.legend()

                    outname = prefix+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                   
                    plt.savefig(outname, dpi=200)
                    plt.show ()
                    
                    #reverse y sequence
                    to_rev=result[:,0,:]
                    to_rev=to_rev.long().cpu().detach().numpy()
                    
                    y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                    for iii in range (len(y_data_reversed)):
                        y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
                        
                    #reverse GT_y sequence
                    to_rev=GT[:,0,:]
                    to_rev=to_rev.long().cpu().detach().numpy()
                    
                    GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                    for iii in range (len(GT_y_data_reversed)):
                        GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")
                    
                    ### reverse secondary structure input....
                    to_rev=torch.round (X_train_batch[:,:]*Xnormfac)
                    to_rev=to_rev.long().cpu().detach().numpy()
                   
                    X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)

                    for iii in range (len(X_data_reversed)):
                        X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")

                    print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples])
                    print (f"Ground truth: {GT_y_data_reversed[samples]}")
                   
                    if foldproteins:
                        # -- dead code: xbc/out_nam were unused (out_nam is rebuilt below)
                        # xbc=X_train_batch[samples,:].cpu().detach().numpy()
                        # out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                        tempname='temp'
                        pdb_file=foldandsavePDB (
                            sequence=y_data_reversed[samples], 
                            filename_out=tempname, 
                            num_cycle=16, flag=flag,
                            # +++++++++++++++++++
                            prefix=prefix
                        )
                        
                        #out_nam=f'{prefix}{out_nam}.pdb'
                        out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
                        print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                        shutil.copy (pdb_file, out_nam) #source, dest
                        pdb_file=out_nam
                        print (f"Properly named PDB file produced: {pdb_file}")
                        
                        view=show_pdb(pdb_file=pdb_file, flag=flag, show_sidechains=show_sidechains,  show_mainchains=show_mainchains, color=color)
                        view.show()

                    steps=steps+1
            if steps>num_samples:
                break
        
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# train_loop tasks:
# 1. calculate loss for one batch
# 2. call sample loop
# 3. call sample sequence
# 4. print records and save model
# ===============================================================
# for ProteinDesigner_B
# 1. expanded for Probability case
# ++
cal_norm_prob = nn.Softmax(dim=2)
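
# ++ a minimal illustration of the probability branch below
# (image_channels==33): per-residue logits from pLM_Model.lm_head are
# normalized over the 33-token axis; dummy logits, kept commented:
# logits = torch.randn(2, 128, 33)      # (batch, seq_len, vocab)
# probs = cal_norm_prob(logits)         # softmax over dim=2
# assert torch.allclose(probs.sum(dim=2), torch.ones(2, 128))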

def train_loop_Model_B (
    model,
    train_loader,
    test_loader,
    #
    optimizer=None,
    print_every=10,
    epochs= 300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss_every_steps=1000,
    #
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=[7.5], #list of cond scales - each sampled...
    num_samples=2, #how many samples produced every time tested.....
    foldproteins=False,
    cond_image=False, #use cond_images...
    # add some
    # +++++++++++++++++++++++++++
    device=None,
    loss_list=[],
    epoch_list=[],
    train_hist_file=None,
    train_hist_file_full=None,
    prefix=None,
    Xnormfac=1.,
    ynormfac=1.,
    tokenizer_X=None,
    tokenizer_y=None,
    test_condition_list=[],
    max_length=1,
    CKeys=None,
    sample_steps=1,
    sample_dir=None,
    save_every_epoch=1,
    save_point_info_file=None,
    store_dir=None,
    # ++
    pLM_Model_Name=None,
    image_channels=None,
    print_error=False,
):
    
    if not exists (trainer):
        if not exists (optimizer):
            print ("ERROR: If trainer not used, need to provide optimizer.")
    if exists (trainer):
        print ("Trainer provided... will be used")
        
    # steps=start_step+1
    # # +
    # added_steps=0+1
    #
    steps=start_step
    added_steps=0
    
    loss_total=0
    
    # ++ for pLM
    if pLM_Model_Name=='None':
        pLM_Model=None
        # ++ keep downstream arguments defined in this branch too
        esm_alphabet=None
        esm_layer=None
        
    elif pLM_Model_Name=='esm2_t33_650M_UR50D':
        # dim: 1280
        esm_layer=33
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
        
    elif pLM_Model_Name=='esm2_t36_3B_UR50D':
        # dim: 2560
        esm_layer=36
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
        
    elif pLM_Model_Name=='esm2_t30_150M_UR50D':
        # dim: 640
        esm_layer=30
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
    
    elif pLM_Model_Name=='esm2_t12_35M_UR50D':
        # dim: 480
        esm_layer=12
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
        
    else:
        print("pLM model name not recognized: ", pLM_Model_Name)

        
    for e in range(1, epochs+1):
            
        # start = time.time()

        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        # TRAINING
        train_epoch_loss = 0
        model.train()

        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item  in train_loader:
            steps += 1
            added_steps += 1

            X_train_batch= item[0].to(device)
            y_train_batch= item[1].to(device)
            # project y_ into embedding space
            if CKeys["Debug_TrainerPack"]==1:
                print("Initial unload the dataloader items: ...")
                print("X_train_batch.dim: ", X_train_batch.shape)
                print("y_train_batch.dim: ", y_train_batch.shape)
                
            # ---------------------------------------------------------
            # prepare for model.forward() to calculate loss
            # ---------------------------------------------------------
            # # --
            # if pLM_Model_Name=='None':
            #     # just use the encoded sequence
            #     y_train_batch_in = y_train_batch.unsqueeze(1)
            #     X_train_batch_in = X_train_batch.unsqueeze(1)
            #     # pass
            # elif pLM_Model_Name=='esm2_t33_650M_UR50D':
            #     with torch.no_grad():
            #         results = pLM_Model(
            #             y_train_batch,
            #             repr_layers=[33],
            #             return_contacts=False,
            #         )
            #     y_train_batch_in = results["representations"][33] # (batch, seq_len, channels)
            #     y_train_batch_in = rearrange(
            #         y_train_batch_in, 
            #         'b l c -> b c l'
            #     )
            #     X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1)
            # else:
            #     print(f"Required pLM name is not defined!!")
            # ++
            if pLM_Model_Name=='None':
                # just use the encoded sequence
                y_train_batch_in = y_train_batch.unsqueeze(1)
                X_train_batch_in = X_train_batch.unsqueeze(1)
                # pass
            else: # assume ESM models
                with torch.no_grad():
                    results = pLM_Model(
                        y_train_batch,
                        repr_layers=[esm_layer],
                        return_contacts=False,
                    )
                    y_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels)
                # ++ for Probability case
                if image_channels==33:
                    with torch.no_grad():
                        # calculate logits: (batch, seq_len, 33)
                        y_train_batch_in = pLM_Model.lm_head(
                            y_train_batch_in
                        )
                        # normalize to get (0,1) probability
                        y_train_batch_in = cal_norm_prob(y_train_batch_in)
                
                # switch the dimension -> (batch, channel, seq_len)
                y_train_batch_in = rearrange(
                    y_train_batch_in, 
                    'b l c -> b c l'
                )
                
                    
                X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1)
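                # ++ shapes at this point (L = seq_len, C = image_channels,
                # i.e. the pLM embedding dim, or 33 for the probability case):
                #   y_train_batch_in: (batch, C, L)  -- embedding / probs
                #   X_train_batch_in: (batch, C, L)  -- condition tiled over C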

                
            # + for debug
            if CKeys["Debug_TrainerPack"]==1:
                print("After pLM model, the shape of X and y for training:")
                print("X_train_batch_in.dim: ", X_train_batch_in.shape)
                print("y_train_batch_in.dim: ", y_train_batch_in.shape)
                
                
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            if exists (trainer):
                
                if cond_image==False:
                    loss = trainer(
                        y_train_batch.unsqueeze(1) , # true image (batch, channels, seq_len)
                        x=X_train_batch,             # tokenized text (batch, )
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size,    # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                    )
                # # ----------------------------------------------------------
                # if cond_image==True:
                #     loss = trainer(
                #         y_train_batch.unsqueeze(1) ,            # true image
                #         x=None,                                 # tokenized text
                #         cond_images=X_train_batch.unsqueeze(1), # cond_image
                #         unet_number=train_unet_number,
                #         max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                #         )
                # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                if cond_image==True:
                    loss = trainer(
                        y_train_batch_in,                          # true image
                        x=None,                                 # tokenized text
                        cond_images=X_train_batch_in,              # cond_image
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                        )
                    
                trainer.update(unet_number = train_unet_number)

            else:
                optimizer.zero_grad()
                if cond_image==False:
                    loss=model (
                        y_train_batch.unsqueeze(1) , 
                        x=X_train_batch, 
                        unet_number=train_unet_number
                    )
                # # ------------------------------------------------------
                # if cond_image==True:
                #     loss=model (
                #         y_train_batch.unsqueeze(1) ,
                #         x=None, 
                #         cond_images=X_train_batch.unsqueeze(1), 
                #         unet_number=train_unet_number
                #     )
                # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
                if cond_image==True:
                    loss=model (
                        y_train_batch_in ,
                        x=None, 
                        cond_images=X_train_batch_in, 
                        unet_number=train_unet_number
                    )
                #
                loss.backward( )
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_total=loss_total+loss.item()
            # +
            train_epoch_loss=train_epoch_loss+loss.item()

            if steps % print_every == 0:
                # for progress bar
                print(".", end="")

            # if steps>0:
            if added_steps>0:
                
                if steps % print_loss_every_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print('I am here')
                        print("Here is steps: ", steps)
                    
                    norm_loss=loss_total/print_loss_every_steps
                    print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}")
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # add a line to the hist file
                    add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n'
                    with open(train_hist_file,'a') as f:
                        f.write(add_line)


                    loss_list.append (norm_loss)
                    loss_total=0
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    epoch_list.append(e+start_ep)
                    
                    # fig = plt.figure(figsize=(12,8),dpi=200)
                    fig = plt.figure()
                    plt.plot (epoch_list, loss_list, label='Loss')
                    plt.legend()

                    # outname = prefix+ f"loss_{e+start_ep}_{steps}.jpg"
                    outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg"
                    # 
                    # the order, save then show, matters
                    if CKeys['SlientRun']==1:
                        plt.savefig(outname, dpi=200)
                    else:
                        plt.show()
                    plt.close(fig)
                    # plt.close()
                    
            if added_steps>0:
                # if steps>0:
                if steps % sample_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print('I am here')
                        print("Here is steps: ", steps)
                    
                    if plot_unscaled:
                        #test before scaling...
                        plt.plot (
                            y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),
                            label= 'Unscaled GT'
                        )
                        plt.legend()
                        plt.show()

#                     # --------------------------------------------------
#                     GT=y_train_batch.cpu().detach() 

#                     GT=resize_image_to(
#                         GT.unsqueeze(1),
#                         model.imagen.image_sizes[train_unet_number-1],
#                     )

                    
                    ####
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    print ("I. SAMPLING IN TEST SET: ")
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    ####
                    # num_samples = min (num_samples,y_train_batch.shape[0] )
                    print (f"Producing {num_samples} samples...")

                    
                    if cond_image == True:
                        use_text_embedd=False
                        # -
                        # cond_scales_extended=[1. for i in range(num_samples)]
                        # +
                        cond_scales_extended=cond_scales
                    else:
                        use_text_embedd=True
                        cond_scales_extended=cond_scales

                    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    sample_loop_omegafold_pLM_ModelB (
                        model,
                        test_loader,
                        cond_scales=cond_scales_extended, # cond_scales,# #list of cond scales - each sampled...
                        num_samples=num_samples, #how many samples produced every time tested.....
                        timesteps=64,
                        flag=steps,
                        #reverse=False,
                        foldproteins=foldproteins,
                        use_text_embedd= use_text_embedd,
                        # ++++++++++++++++++++
                        train_unet_number=train_unet_number,
                        ynormfac=ynormfac,
                        prefix=prefix,
                        tokenizer_y=tokenizer_y,
                        Xnormfac=Xnormfac,
                        tokenizer_X=tokenizer_X,
                        # ++
                        # ++
                        CKeys=CKeys,
                        sample_dir=sample_dir,
                        steps=steps,
                        e=e+start_ep,
                        IF_showfig= CKeys['SlientRun']!=1,
                        # ++
                        pLM_Model=pLM_Model,
                        pLM_Model_Name=pLM_Model_Name,
                        image_channels=image_channels,
                        pLM_alphabet=esm_alphabet,
                    )   
                    
                    
                    #---------------------------------------------------------
                    # sample_loop (
                    #     model,
                    #     test_loader,
                    #     cond_scales=cond_scales,# #list of cond scales - each sampled...
                    #     num_samples=num_samples, #how many samples produced every time tested.....
                    #     timesteps=64,
                    #     flag=steps,
                    #     #reverse=False,
                    #     foldproteins=foldproteins,
                    #     use_text_embedd= use_text_embedd,
                    #     # ++++++++++++++++++++
                    #     train_unet_number=train_unet_number,
                    #     ynormfac=ynormfac,
                    #     prefix=prefix,
                    #     tokenizer_y=tokenizer_y,
                    #     Xnormfac=Xnormfac,
                    #     tokenizer_X=tokenizer_X,
                    # )   

                    #index_word': '{"1": "~", "2": "h", "3": "e", "4": "s", "5": "t", "6": "g", "7": "b", "8": "i"}', 
                    #'word_index': '{"~": 1, "h": 2, "e": 3, "s": 4, "t": 5, "g": 6, "b": 7, "i": 8}'}

                    AH_code=2/Xnormfac
                    BS_code=3/Xnormfac
                    unstr_code= 1/Xnormfac

                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    print ("II. SAMPLING FOR DE NOVO:")
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    
                    # +++++++++++++++++++++++++++++++++++++++++
                    DeNovoSam_pdbs, fasta_file_list=\
                    sample_sequence_omegafold_pLM_ModelB (
                        model,
                        x_data=test_condition_list,
                        flag=steps, # flag="DeNovo", # ,
                        cond_scales=1.,
                        foldproteins=foldproteins,
                        # ++++++++++
                        ynormfac=ynormfac,
                        train_unet_number=train_unet_number,
                        tokenizer_X=tokenizer_X,
                        Xnormfac=Xnormfac,
                        max_length=max_length,
                        prefix=prefix,
                        tokenizer_y=tokenizer_y,
                        # ++
                        CKeys=CKeys,
                        sample_dir=sample_dir,
                        steps=steps,
                        e=e+start_ep,
                        IF_showfig= CKeys['SlientRun']!=1,
                        # ++
                        pLM_Model=pLM_Model,
                        pLM_Model_Name=pLM_Model_Name,
                        image_channels=image_channels,
                        pLM_alphabet=esm_alphabet,
                       )
                    
                    if print_error and len(DeNovoSam_pdbs)>0:
                        print("Calculate SecStr and design error:")
                        #
                        for ii in range(len(test_condition_list)):
                            seq=test_condition_list[ii][0]
                            DSSPresult,_,sequence_res=get_DSSP_result(DeNovoSam_pdbs[ii]) 
                            print (f"INPUT:        {seq}\nRESULT:       {DSSPresult}\nAA sequence:  {sequence_res}")
                            error=string_diff (DSSPresult, seq)/len (seq)
                            print ("Error: ", error)
                        
                    
                    
                    # # +++++++++++++++++++++++++++++++++++++++++
                    # sample_sequence_omegafold_ModelB (
                    #     model,
                    #     x_data=test_condition_list,
                    #     flag=steps, # flag="DeNovo", # ,
                    #     cond_scales=1.,
                    #     foldproteins=foldproteins,
                    #     # ++++++++++
                    #     ynormfac=ynormfac,
                    #     train_unet_number=train_unet_number,
                    #     tokenizer_X=tokenizer_X,
                    #     Xnormfac=Xnormfac,
                    #     max_length=max_length,
                    #     prefix=prefix,
                    #     tokenizer_y=tokenizer_y,
                    #     # ++
                    #     CKeys=CKeys,
                    #     sample_dir=sample_dir,
                    #     steps=steps,
                    #     e=e+start_ep,
                    #     IF_showfig= CKeys['SlientRun']!=1,
                    #    )
                        
                    # for this_x_data in test_condition_list:
                    #     sample_sequence_omegafold (
                    #         model,
                    #         x_data=this_x_data,
                    #         flag=steps, # flag="DeNovo", # ,
                    #         cond_scales=1.,
                    #         foldproteins=True,
                    #         # ++++++++++
                    #         ynormfac=ynormfac,
                    #         train_unet_number=train_unet_number,
                    #         tokenizer_X=tokenizer_X,
                    #         Xnormfac=Xnormfac,
                    #         max_length=max_length,
                    #         prefix=prefix,
                    #         tokenizer_y=tokenizer_y,
                    #         # ++
                    #         CKeys=CKeys,
                    #         sample_dir=sample_dir,
                    #         steps=steps,
                    #         e=e+start_ep,
                    #         IF_showfig= CKeys['SlientRun']!=1,
                    #        )
                    #
                    
                    # -----------------------------------------    
                    # sample_sequence (
                    #     model,
                    #     x_data=['~~~HHHHHHHHHHHHHHH~~'],
                    #     flag=steps,cond_scales=1.,
                    #     foldproteins=True,
                    #     # ++++++++++
                    #     ynormfac=ynormfac,
                    #    )
                    # sample_sequence (
                    #     model,
                    #     x_data=['~~~HHHHHHHHHHHHHHH~~~~HHHHHHHHHHHHHH~~~'],
                    #     flag=steps,cond_scales=1.,
                    #     foldproteins=True,
                    #     # ++++++++++
                    #     ynormfac=ynormfac,
                    #    )
                    # sample_sequence (
                    #     model,
                    #     x_data=['~~EEESSTTS~SEEEEEEEEE~SBS~EEEEEE~~'],
                    #     flag=steps,cond_scales=1.,
                    #     foldproteins=True,
                    #     # ++++++++++++
                    #     ynormfac=ynormfac,
                    #    )

            # if steps>0:
            # # --------------------------------------------------------------------
            # if added_steps>0:
            #     if save_model and steps % print_loss_every_steps==0:
            #         fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt"
            #         trainer.save(fname)
            #         print (f"Model saved: ", fname)
            #         fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt"
            #         torch.save(model.state_dict(), fname)
            #         print (f"Statedict model saved: ", fname)
            
            # steps=steps+1
            # added_steps += 1
            
        # every epoch:
        norm_loss_over_e = train_epoch_loss/len(train_loader)
        print("\nnorm_loss over 1 epoch: ", norm_loss_over_e)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # write this into "train_hist_file_full"
        add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
        with open(train_hist_file_full,'a') as f:
            f.write(add_line)
            
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # save model every this epoches
        if save_model and (e+start_ep) % save_every_epoch==0 and e>1:
            # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt"
            trainer.save(fname)
            print (f"Model saved: ", fname)
            # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt"
            torch.save(model.state_dict(), fname)
            print (f"Statedict model saved: ", fname)
            # add a saving point file
            top_line='epoch,steps,norm_loss'+'\n'
            # use the per-epoch loss here (the windowed norm_loss may be stale or unset)
            add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
            with open(save_point_info_file, "w") as f:
                f.write(top_line)
                f.write(add_line)
        


        print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------")
        
# ===============================================================
# for ProteinPredictor_B
#
def train_loop_Model_B_Predictor (
    model,
    train_loader,
    test_loader,
    #
    optimizer=None,
    print_every=10,
    epochs= 300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss_every_steps=1000,
    #
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=[1.], #list of cond scales - each sampled...
    num_samples=2, #how many samples produced every time tested.....
    foldproteins=False,
    cond_image=False, #use cond_images...
    # add some
    # +++++++++++++++++++++++++++
    device=None,
    loss_list=[],
    epoch_list=[],
    train_hist_file=None,
    train_hist_file_full=None,
    prefix=None,
    Xnormfac=1.,
    ynormfac=1.,
    tokenizer_X=None,
    tokenizer_y=None,
    test_condition_list=[],
    max_length=1,
    CKeys=None,
    sample_steps=1,
    sample_dir=None,
    save_every_epoch=1,
    save_point_info_file=None,
    store_dir=None,
    # ++
    pLM_Model_Name=None,
    image_channels=None,
    print_error=False,
    # ++
    train_hist_file_on_testset=None,
):
    
    if not exists (trainer):
        if not exists (optimizer):
            print ("ERROR: If trainer not used, need to provide optimizer.")
    if exists (trainer):
        print ("Trainer provided... will be used")
        
    # steps=start_step+1
    # # +
    # added_steps=0+1
    #
    steps=start_step
    added_steps=0
    
    loss_total=0
    
    # ++ for pLM
    # ++
    pLM_Model, esm_alphabet, \
    esm_layer, len_toks = load_in_pLM(
        pLM_Model_Name,
        device,
    )

        
    for e in range(1, epochs+1):
            
        # start = time.time()

        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        # TRAINING
        train_epoch_loss = 0
        model.train()

        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item  in train_loader:
            steps += 1
            added_steps += 1

            X_train_batch= item[0].to(device)
            y_train_batch= item[1].to(device)
            # project y_ into embedding space
            if CKeys["Debug_TrainerPack"]==1:
                print("Initial unload the dataloader items: ...")
                print("X_train_batch.dim: ", X_train_batch.shape)
                print("y_train_batch.dim: ", y_train_batch.shape)
                
            # ---------------------------------------------------------
            # prepare for model.forward() to calculate loss
            # ---------------------------------------------------------
            # ++
            if pLM_Model_Name=='trivial':
                # just use the encoded sequence
                y_train_batch_in = y_train_batch.unsqueeze(1)
                X_train_batch_in = X_train_batch.unsqueeze(1)
                # pass
            else: 
                # assume ESM models
                # --
                # # for ProteinDesigner
                # with torch.no_grad():
                #     results = pLM_Model(
                #         y_train_batch,
                #         repr_layers=[esm_layer],
                #         return_contacts=False,
                #     )
                # y_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels)
                # y_train_batch_in = rearrange(
                #     y_train_batch_in, 
                #     'b l c -> b c l'
                # )
                # X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1)
                #
                # ++
                # for ProteinPredictor
                with torch.no_grad():
                    results = pLM_Model(
                        X_train_batch,
                        repr_layers=[esm_layer],
                        return_contacts=False,
                    )
                X_train_batch_in = results["representations"][esm_layer] # (batch, seq_len, channels)
                X_train_batch_in = rearrange(
                    X_train_batch_in, 
                    'b l c -> b c l'
                )
                y_train_batch_in = y_train_batch.unsqueeze(1).repeat(1,image_channels,1)
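                # ++ compared with the Designer loop above, X and y swap roles
                # here: the AA sequence X is embedded by the pLM, while the
                # per-residue target y is tiled across channels, so both end
                # up with shape (batch, image_channels, seq_len)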
                

                
            # + for debug
            if CKeys["Debug_TrainerPack"]==1:
                print("After pLM model, the shape of X and y for training:")
                print("X_train_batch_in.dim: ", X_train_batch_in.shape)
                print("y_train_batch_in.dim: ", y_train_batch_in.shape)
                
                
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            if exists (trainer):
                
                if cond_image==False:
                    loss = trainer(
                        y_train_batch.unsqueeze(1) , # true image (batch, channels, seq_len)
                        x=X_train_batch,             # tokenized text (batch, )
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size,    # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                    )
                # # ----------------------------------------------------------
                # if cond_image==True:
                #     loss = trainer(
                #         y_train_batch.unsqueeze(1) ,            # true image
                #         x=None,                                 # tokenized text
                #         cond_images=X_train_batch.unsqueeze(1), # cond_image
                #         unet_number=train_unet_number,
                #         max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                #         )
                # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                if cond_image==True:
                    loss = trainer(
                        y_train_batch_in,                          # true image
                        x=None,                                 # tokenized text
                        cond_images=X_train_batch_in,              # cond_image
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                        )
                    
                trainer.update(unet_number = train_unet_number)

            else:
                optimizer.zero_grad()
                if cond_image==False:
                    loss=model (
                        y_train_batch.unsqueeze(1) , 
                        x=X_train_batch, 
                        unet_number=train_unet_number
                    )
                # # ------------------------------------------------------
                # if cond_image==True:
                #     loss=model (
                #         y_train_batch.unsqueeze(1) ,
                #         x=None, 
                #         cond_images=X_train_batch.unsqueeze(1), 
                #         unet_number=train_unet_number
                #     )
                # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
                if cond_image==True:
                    loss=model (
                        y_train_batch_in ,
                        x=None, 
                        cond_images=X_train_batch_in, 
                        unet_number=train_unet_number
                    )
                #
                loss.backward( )
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_total=loss_total+loss.item()
            # +
            train_epoch_loss=train_epoch_loss+loss.item()

            if steps % print_every == 0:
                # for progress bar
                print(".", end="")

            # if steps>0:
            if added_steps>0:
                
                if steps % print_loss_every_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print('I am here')
                        print("Here is steps: ", steps)
                    
                    norm_loss=loss_total/print_loss_every_steps
                    print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}")
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # add a line to the hist file
                    add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n'
                    with open(train_hist_file,'a') as f:
                        f.write(add_line)


                    loss_list.append (norm_loss)
                    loss_total=0
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    epoch_list.append(e+start_ep)
                    
                    # fig = plt.figure(figsize=(12,8),dpi=200)
                    fig = plt.figure()
                    plt.plot (epoch_list, loss_list, label='Loss')
                    plt.legend()

                    # outname = prefix+ f"loss_{e+start_ep}_{steps}.jpg"
                    outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg"
                    # 
                    # the order, save then show, matters
                    if CKeys['SlientRun']==1:
                        plt.savefig(outname, dpi=200)
                    else:
                        plt.show()
                    plt.close(fig)
                    # plt.close()
                    
            if added_steps>0:
                # if steps>0:
                if sample_steps > 0 and steps % sample_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print('I am here')
                        print("Here is steps: ", steps)
                    
                    if plot_unscaled:
                        #test before scaling...
                        plt.plot (
                            y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),
                            label= 'Unscaled GT'
                        )
                        plt.legend()
                        plt.show()

#                     # --------------------------------------------------
#                     GT=y_train_batch.cpu().detach() 

#                     GT=resize_image_to(
#                         GT.unsqueeze(1),
#                         model.imagen.image_sizes[train_unet_number-1],
#                     )

                    
                    ####
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    print ("I. SAMPLING IN TEST SET: ")
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    ####
                    # num_samples = min (num_samples,y_train_batch.shape[0] )
                    print (f"Producing {num_samples} samples...")

                    
                    if cond_image == True:
                        use_text_embedd=False
                        # -
                        # cond_scales_extended=[1. for i in range(num_samples)]
                        # +
                        cond_scales_extended=cond_scales
                    else:
                        use_text_embedd=True
                        cond_scales_extended=cond_scales

                    # # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # # For ProteinDesigner
                    # sample_loop_omegafold_pLM_ModelB (
                    #     model,
                    #     test_loader,
                    #     cond_scales=cond_scales_extended, # cond_scales,# #list of cond scales - each sampled...
                    #     num_samples=num_samples, #how many samples produced every time tested.....
                    #     timesteps=64,
                    #     flag=steps,
                    #     #reverse=False,
                    #     foldproteins=foldproteins,
                    #     use_text_embedd= use_text_embedd,
                    #     # ++++++++++++++++++++
                    #     train_unet_number=train_unet_number,
                    #     ynormfac=ynormfac,
                    #     prefix=prefix,
                    #     tokenizer_y=tokenizer_y,
                    #     Xnormfac=Xnormfac,
                    #     tokenizer_X=tokenizer_X,
                    #     # ++
                    #     # ++
                    #     CKeys=CKeys,
                    #     sample_dir=sample_dir,
                    #     steps=steps,
                    #     e=e+start_ep,
                    #     IF_showfig= CKeys['SlientRun']!=1,
                    #     # ++
                    #     pLM_Model=pLM_Model,
                    #     pLM_Model_Name=pLM_Model_Name,
                    #     image_channels=image_channels,
                    #     pLM_alphabet=esm_alphabet,
                    # )
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # For ProteinPredictor:
                    val_epoch_MSE_list, \
                    resu_pred, resu_grou = \
                    sample_loop_omegafold_pLM_ModelB_Predictor (
                        model,
                        test_loader,
                        cond_scales=[1.], # cond_scales_extended, # #list of cond scales - each sampled...
                        num_samples=num_samples, #how many samples produced every time tested.....
                        timesteps=64,
                        flag=steps,
                        #reverse=False,
                        foldproteins=foldproteins,
                        use_text_embedd= use_text_embedd,
                        # ++++++++++++++++++++
                        train_unet_number=train_unet_number,
                        ynormfac=ynormfac,
                        prefix=prefix,
                        tokenizer_y=tokenizer_y,
                        Xnormfac=Xnormfac,
                        tokenizer_X=tokenizer_X,
                        # ++
                        # ++
                        CKeys=CKeys,
                        sample_dir=sample_dir,
                        steps=steps,
                        e=e+start_ep,
                        IF_showfig= CKeys['SlientRun']!=1,
                        # ++
                        pLM_Model=pLM_Model,
                        pLM_Model_Name=pLM_Model_Name,
                        image_channels=image_channels,
                        pLM_alphabet=esm_alphabet,
                        # ++
                        esm_layer=esm_layer,
                    )
                    # record the ERROR on the test set
                    print(f"Epo {str(e+start_ep)}, on TestSet, MSE: {val_epoch_MSE_list[0]}")
                    # only write the 0th case of MSE
                    add_line = str(e+start_ep)+','+str(steps)+','+str(val_epoch_MSE_list[0])+'\n'
                    with open(train_hist_file_on_testset,'a') as f:
                        f.write(add_line)


                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    print ("II. SAMPLING FOR DE NOVO: NOT USED in predictor mode")
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    
#                     # +++++++++++++++++++++++++++++++++++++++++
#                     DeNovoSam_pdbs, fasta_file_list=\
#                     sample_sequence_omegafold_pLM_ModelB (
#                         model,
#                         x_data=test_condition_list,
#                         flag=steps, # flag="DeNovo", # ,
#                         cond_scales=1.,
#                         foldproteins=foldproteins,
#                         # ++++++++++
#                         ynormfac=ynormfac,
#                         train_unet_number=train_unet_number,
#                         tokenizer_X=tokenizer_X,
#                         Xnormfac=Xnormfac,
#                         max_length=max_length,
#                         prefix=prefix,
#                         tokenizer_y=tokenizer_y,
#                         # ++
#                         CKeys=CKeys,
#                         sample_dir=sample_dir,
#                         steps=steps,
#                         e=e+start_ep,
#                         IF_showfig= CKeys['SlientRun']!=1,
#                         # ++
#                         pLM_Model=pLM_Model,
#                         pLM_Model_Name=pLM_Model_Name,
#                         image_channels=image_channels,
#                         pLM_alphabet=esm_alphabet,
#                        )
                    
#                     if print_error and len(DeNovoSam_pdbs)>0:
#                         print("Calculate SecStr and design error:")
#                         #
#                         for ii in range(len(test_condition_list)):
#                             seq=test_condition_list[ii][0]
#                             DSSPresult,_,sequence_res=get_DSSP_result(DeNovoSam_pdbs[ii]) 
#                             print (f"INPUT:        {seq}\nRESULT:       {DSSPresult}\nAA sequence:  {sequence_res}")
#                             error=string_diff (DSSPresult, seq)/len (seq)
#                             print ("Error: ", error)
                        
                    
            
        # every epoch:
        norm_loss_over_e = train_epoch_loss/len(train_loader)
        print("\nnorm_loss over 1 epoch: ", norm_loss_over_e)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # write this into "train_hist_file_full"
        add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
        with open(train_hist_file_full,'a') as f:
            f.write(add_line)
            
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # save the model every `save_every_epoch` epochs
        if save_model and (e+start_ep) % save_every_epoch==0 and e>1:
            # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt"
            trainer.save(fname)
            print ("Model saved: ", fname)
            # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt"
            torch.save(model.state_dict(), fname)
            print ("Statedict model saved: ", fname)
            # add a saving point file
            # use the per-epoch loss here: `norm_loss` can be unbound if the
            # printing interval was never reached during this run
            top_line='epoch,steps,norm_loss'+'\n'
            add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
            with open(save_point_info_file, "w") as f:
                f.write(top_line)
                f.write(add_line)
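            # Resume sketch (hedged: assumes the trainer exposes a `load`
            # matching the `save` used above):
            #   trainer.load(fname)                          # full trainer checkpoint
            #   model.load_state_dict(torch.load(fname))     # bare state dict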
        


        print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------")
            
            
def train_loop_Old_FromModelB (model,
                train_loader,
                test_loader,
                optimizer=None,
                print_every=10,
                epochs= 300,
                start_ep=0,
                start_step=0,
                train_unet_number=1,
                print_loss=1000,
                trainer=None,
                plot_unscaled=False,
                max_batch_size=4,
                save_model=False,
                cond_scales=[7.5], #list of cond scales - each sampled...
                num_samples=2, #how many samples produced every time tested.....
                foldproteins=False,
                cond_image=False, #use cond_images...
                # add some
                # +++++++++++++++++++++++++++
                device=None,
                loss_list=[],
                prefix=None,
                ynormfac=1,
                test_condition_list=[],
                tokenizer_y=None,
                Xnormfac=1,
                tokenizer_X=None,
                max_length=1,
                
               ):
    
    if not exists (trainer):
        if not exists (optimizer):
            print ("ERROR: If trainer not used, need to provide optimizer.")
    if exists (trainer):
        print ("Trainer provided... will be used")
        
    steps=start_step
    
    loss_total=0
    for e in range(1, epochs+1):
        
        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        train_epoch_loss = 0
        model.train()

        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item  in train_loader:

            X_train_batch= item[0].to(device)

            y_train_batch=item[1].to(device)

            if exists (trainer):
                if cond_image==False:
                    loss = trainer(
                        y_train_batch.unsqueeze(1) ,
                        x=X_train_batch,  
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size,    # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                    )
                if cond_image==True:

                    loss = trainer(
                        y_train_batch.unsqueeze(1) ,x=None,
                        cond_images=X_train_batch.unsqueeze(1), 
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                        )
                trainer.update(unet_number = train_unet_number)

            else:
                optimizer.zero_grad()
                if cond_image==False:
                    loss=model (y_train_batch.unsqueeze(1) , x=X_train_batch, unet_number=train_unet_number)
                if cond_image==True:
                    loss=model (y_train_batch.unsqueeze(1) ,x=None, cond_images=X_train_batch.unsqueeze(1), unet_number=train_unet_number)

                loss.backward( )

                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

                optimizer.step()

            loss_total=loss_total+loss.item()

            if steps % print_every == 0:
                print(".", end="")

            if steps>0:
                if steps % print_loss == 0:

                    if plot_unscaled:
                        #test before scaling...
                        plt.plot (y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),label= 'Unscaled GT')
                        plt.legend()
                        plt.show()


                    GT=y_train_batch.cpu().detach() 

                    GT=resize_image_to(
                        GT.unsqueeze(1),
                        model.imagen.image_sizes[train_unet_number-1],

                    )

                    norm_loss=loss_total/print_loss
                    print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}")

                    loss_list.append (norm_loss)
                    loss_total=0

                    plt.plot (loss_list, label='Loss')
                    plt.legend()

                    outname = prefix+ f"loss_{e}_{steps}.jpg"
                    plt.savefig(outname, dpi=200)
                    plt.show()

                    ####
                    num_samples = min (num_samples,y_train_batch.shape[0] )
                    print (f"Producing {num_samples} samples...")


                    if cond_image == True:
                        use_text_embedd=False
                    else:
                        use_text_embedd=True

                    sample_loop_FromModelB (
                        model,
                        test_loader,
                        cond_scales=cond_scales,# #list of cond scales - each sampled...
                        num_samples=num_samples, #how many samples produced every time tested.....
                        timesteps=64,
                        flag=steps,
                        #reverse=False,
                        foldproteins=foldproteins,
                        use_text_embedd= use_text_embedd,
                        # ++++++++++++++++++++
                        train_unet_number=train_unet_number,
                        ynormfac=ynormfac,
                        prefix=prefix,
                        tokenizer_y=tokenizer_y,
                        Xnormfac=Xnormfac,
                        tokenizer_X=tokenizer_X,
                    )   

                    #index_word': '{"1": "~", "2": "h", "3": "e", "4": "s", "5": "t", "6": "g", "7": "b", "8": "i"}', 
                    #'word_index': '{"~": 1, "h": 2, "e": 3, "s": 4, "t": 5, "g": 6, "b": 7, "i": 8}'}

                    AH_code=2/Xnormfac
                    BS_code=3/Xnormfac
                    unstr_code= 1/Xnormfac

                    print ("SAMPLING FOR DE NOVO:")
                    
                    # +++++++++++++++++++++++++++++++++++++++++
                    for this_x_data in test_condition_list:
                        sample_sequence_FromModelB (
                            model,
                            x_data=this_x_data,
                            flag=steps,cond_scales=1.,
                            foldproteins=True,
                            # ++++++++++
                            ynormfac=ynormfac,
                            train_unet_number=train_unet_number,
                            tokenizer_X=tokenizer_X,
                            Xnormfac=Xnormfac,
                            max_length=max_length,
                            prefix=prefix,
                            tokenizer_y=tokenizer_y,
                           )
                    # -----------------------------------------    
                    # sample_sequence (
                    #     model,
                    #     x_data=['~~~HHHHHHHHHHHHHHH~~'],
                    #     flag=steps,cond_scales=1.,
                    #     foldproteins=True,
                    #     # ++++++++++
                    #     ynormfac=ynormfac,
                    #    )
                    # sample_sequence (
                    #     model,
                    #     x_data=['~~~HHHHHHHHHHHHHHH~~~~HHHHHHHHHHHHHH~~~'],
                    #     flag=steps,cond_scales=1.,
                    #     foldproteins=True,
                    #     # ++++++++++
                    #     ynormfac=ynormfac,
                    #    )
                    # sample_sequence (
                    #     model,
                    #     x_data=['~~EEESSTTS~SEEEEEEEEE~SBS~EEEEEE~~'],
                    #     flag=steps,cond_scales=1.,
                    #     foldproteins=True,
                    #     # ++++++++++++
                    #     ynormfac=ynormfac,
                    #    )

            if steps>0:
                if save_model and steps % print_loss==0:
                    fname=f"{prefix}trainer_save-model-epoch_{e}.pt"
                    trainer.save(fname)
                    print (f"Model saved: ", fname)
                    fname=f"{prefix}statedict_save-model-epoch_{e}.pt"
                    torch.save(model.state_dict(), fname)
                    print (f"Statedict model saved: ", fname)

            steps=steps+1

        print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------")
            

# ++++++++++++++++++++++++++++++++++++++++++++++
def foldandsavePDB_pdb_fasta (
    sequence, 
    filename_out, 
    num_cycle=16, 
    flag=0,
    # ++++++++++++
    prefix=None,
):
    
    filename=f"{prefix}fasta_in_{flag}.fasta"
    print ("Writing FASTA file: ", filename)
    OUTFILE=f"{filename_out}_{flag}"
    with open (filename, mode ='w') as f:
        f.write (f'>{OUTFILE}\n')
        f.write (f'{sequence}')
        
    print (f"Now run OmegaFold.... on device={device}")    
    # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device
    cmd_line=f"omegafold {filename} {prefix} --num_cycle {num_cycle} --device={device}"
    print(os.popen(cmd_line).read())
    
    print ("Done OmegaFold")
    
    # PDB_result=f"{prefix}{OUTFILE}.PDB"
    PDB_result=f"{prefix}{OUTFILE}.pdb"
    print (f"Resulting PDB file...:  {PDB_result}")
    
    return PDB_result, filename
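
# Usage sketch for foldandsavePDB_pdb_fasta (hypothetical sequence and paths;
# assumes the `omegafold` CLI is installed and on PATH):
#
# pdb_file, fasta_file = foldandsavePDB_pdb_fasta(
#     sequence='MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ',
#     filename_out='sample',
#     num_cycle=16,
#     flag=0,
#     prefix='./output/',
# )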



def foldandsavePDB (
    sequence, 
    filename_out, 
    num_cycle=16, 
    flag=0,
    # ++++++++++++
    prefix=None,
):
    
    filename=f"{prefix}fasta_in_{flag}.fasta"
    print ("Writing FASTA file: ", filename)
    OUTFILE=f"{filename_out}_{flag}"
    with open (filename, mode ='w') as f:
        f.write (f'>{OUTFILE}\n')
        f.write (f'{sequence}')
        
    print (f"Now run OmegaFold.... on device={device}")    
    # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device
    cmd_line=f"omegafold {filename} {prefix} --num_cycle {num_cycle} --device={device}"
    print(os.popen(cmd_line).read())
    
    print ("Done OmegaFold")
    
    # PDB_result=f"{prefix}{OUTFILE}.PDB"
    PDB_result=f"{prefix}{OUTFILE}.pdb"
    print (f"Resulting PDB file...:  {PDB_result}")
    
    return PDB_result

import py3Dmol
def plot_plddt_legend(dpi=100):
  thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
  plt.figure(figsize=(1,0.1),dpi=dpi)
  ########################################
  for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]:
    plt.bar(0, 0, color=c)
  plt.legend(thresh, frameon=False,
             loc='center', ncol=6,
             handletextpad=1,
             columnspacing=1,
             markerscale=0.5,)
  plt.axis(False)
  return plt
color = "lDDT" # choose from ["chain", "lDDT", "rainbow"]
show_sidechains = False #choose from {type:"boolean"}
show_mainchains = False #choose from {type:"boolean"}

def show_pdb(pdb_file, flag=0,   show_sidechains=False, show_mainchains=False, color="lDDT"):
  model_name = f"Flag_{flag}"
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  view.addModel(open(pdb_file,'r').read(),'pdb')

  if color == "lDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    chains = len(queries[0][1]) + 1 if is_complex else 1
    for n,chain,color in zip(range(chains),list("ABCDEFGH"),
                     ["lime","cyan","magenta","yellow","salmon","white","blue","orange"]):
      view.setStyle({'chain':chain},{'cartoon': {'color':color}})
  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  if color == "lDDT":
      plot_plddt_legend().show() 
  return view
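
# Usage sketch for show_pdb (hypothetical PDB path; meant for notebook display):
# view = show_pdb('./output/sample_0.pdb', flag=0, color='lDDT')
# view.show()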

from Bio.PDB import PDBParser  # Biopython; needed for PDBParser below

def get_avg_Bfac (file='./output_v3/[0.0 0.5 0.0 0.0 0.0 0.0 0.0 0.0].pdb'):
    p = PDBParser()
    avg_B=0
    bfac_list=[]
    
    structure = p.get_structure("X", file)
    for PDBmodel in structure:
        for chain in PDBmodel:
            for residue in chain:
                for atom in residue:
                    Bfac=atom.get_bfactor()
                    bfac_list.append(Bfac)
                    avg_B=avg_B+Bfac
                       
    avg_B=avg_B/len (bfac_list)
    print (f"For {file}, average B-factor={avg_B}")
    plt.plot (bfac_list, label='lDDT')
    plt.xlabel ('Atom #')
    plt.ylabel ('lDDT')
    plt.legend()
    plt.show()
    return avg_B, bfac_list
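
# Note: OmegaFold (like AlphaFold) writes per-residue pLDDT confidence into the
# PDB B-factor column, so the average computed above is a folding-confidence
# proxy rather than a crystallographic temperature factor.
# Usage sketch (hypothetical path):
#   avg_plddt, plddt_list = get_avg_Bfac(file='./output/sample_0.pdb')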

def sample_sequence_normalized_Bfac (seccs=[0.3, 0.3, 0.1, 0., 0., 0., 0., 0. ]):
    sample_numbers=torch.tensor([seccs])
    sample_numbers=torch.nn.functional.normalize (sample_numbers, dim=1)
    sample_numbers=sample_numbers/torch.sum(sample_numbers)

    # NOTE: relies on the module-level globals `model` and `sample_sequence`
    PDB=sample_sequence (model,
                    X=sample_numbers,
                    flag=0, cond_scales=1, foldproteins=True,
                   )

    avg,_ = get_avg_Bfac (file=PDB[0])

    return PDB, avg
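
# Normalization sketch: F.normalize(..., dim=1) first rescales `seccs` to unit
# L2 norm, and the division by the sum then makes the entries sum to 1. Since
# the second step is scale-invariant, the net result equals seccs/sum(seccs):
#   [0.3, 0.3, 0.1] -> [3/7, 3/7, 1/7] ~ [0.429, 0.429, 0.143]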

# ======================================================
# blocks for Model A
# ======================================================
def train_loop_Old_FromModelA (
    model,
    train_loader,
    test_loader,
    #
    optimizer=None,
    print_every=1,
    epochs= 300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss_every_steps=1000,
    #
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=[1.0], #list of cond scales
    num_samples=2, #how many samples produced every time tested.....
    foldproteins=False,
    # ++
    cond_image=False, # not use cond_images... for model A
    cond_text=True,   # use condi_text...      for model A
    # +
    device=None,
    loss_list=[],
    epoch_list=[],
    train_hist_file=None,
    train_hist_file_full=None,
    prefix=None, # not used in this function
    Xnormfac=None,
    ynormfac=1.,
    tokenizer_X=None,
    tokenizer_y=None,
    test_condition_list=[],
    max_length_Y=1,
    max_text_len_X=1,
    CKeys=None,
    sample_steps=1,
    sample_dir=None,
    save_every_epoch=1,
    save_point_info_file=None,
    store_dir=None,
):
    # #+
    # Xnormfac=Xnormfac.to(model.device)
    
    if not exists (trainer):
        if not exists (optimizer):
            print ("ERROR: If trainer not used, need to provide optimizer.")
    if exists (trainer):
        print ("Trainer provided... will be used")
    # --------------------------------
    # steps=start_step
    # ++++++++++++++++++++++++++++++++
    steps=start_step
    added_steps=0

    loss_total=0
    for e in range(1, epochs+1):
        # start = time.time()

        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        # TRAINING
        train_epoch_loss = 0
        model.train()

        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item  in train_loader:
            # ++
            steps += 1
            added_steps += 1

            X_train_batch= item[0].to(device)
            y_train_batch=item[1].to(device)

            if exists (trainer):
                if cond_image==False:
                    # ========================================
                    # Model A: condition via text
                    # ========================================
                    # this block depends on the model:forward
                    loss = trainer(
                        # # --------------------------------
                        # X_train_batch, 
                        # y_train_batch.unsqueeze(1) ,
                        # ++++++++++++++++++++++++++++++++
                        y_train_batch.unsqueeze(1) ,
                        x=X_train_batch, 
                        # 
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size,    # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                    )
                if cond_image==True:
                    # ========================================
                    # Model B: condition via image/sequence
                    # ========================================
                    # added for future: Train_loop B
                    loss = trainer(
                        y_train_batch.unsqueeze(1) ,
                        x=None,
                        cond_images=X_train_batch.unsqueeze(1), 
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                    )
                    # pass
                #
                trainer.update(unet_number = train_unet_number)

            else:
                optimizer.zero_grad()
                if cond_image==False:
                    # this block depends on the model:forward
                    loss=model ( 
                        # # --------------------------------
                        # X_train_batch, 
                        # y_train_batch.unsqueeze(1) ,
                        # ++++++++++++++++++++++++++++++++
                        y_train_batch.unsqueeze(1) ,
                        x=X_train_batch,
                        #
                        unet_number=train_unet_number
                    )
                if cond_image==True:
                    # added for future: Train_loop B
                    loss=model (
                        y_train_batch.unsqueeze(1) ,
                        x=None, 
                        cond_images=X_train_batch.unsqueeze(1), 
                        unet_number=train_unet_number
                    )
                #
                loss.backward( )
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_total=loss_total+loss.item()
            # +
            train_epoch_loss=train_epoch_loss+loss.item()

            if steps % print_every == 0:
                # for progress bar
                print(".", end="")

            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # record loss block
            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # if steps>0:
            if added_steps>0:

                if steps % print_loss_every_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print("Here is step: ", steps)

                    norm_loss=loss_total/print_loss_every_steps
                    print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}")
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # add a line to the hist file
                    add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n'
                    with open(train_hist_file,'a') as f:
                        f.write(add_line)

                    loss_list.append (norm_loss)
                    loss_total=0
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    epoch_list.append(e+start_ep)

                    fig = plt.figure()
                    plt.plot (epoch_list, loss_list, label='Loss')
                    plt.legend()
                    # outname = prefix+ f"loss_{e}_{steps}.jpg"
                    outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg"
                    # 
                    # the order, save then show, matters
                    if CKeys['SlientRun']==1:
                        plt.savefig(outname, dpi=200)
                    else:
                        plt.show()
                    plt.close(fig)

            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # sample in test set block
            # set sample_steps < 0 to switch off this block
            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # if steps>0:
            if added_steps>0:
                if sample_steps > 0 and steps % sample_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print("Here is steps: ", steps)

                    if plot_unscaled:
                        # test before scaling...
                        plt.plot (
                            y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),
                            label= 'Unscaled GT'
                        )
                        plt.legend()
                        plt.show()

                    #rescale GT to properly plot
                    GT=y_train_batch.cpu().detach() 

                    GT=resize_image_to(
                        GT.unsqueeze(1),
                        model.imagen.image_sizes[train_unet_number-1],

                    )
                    ####
                    print ("I. SAMPLING IN TEST SET: ")
                    ####

                    num_samples = min (num_samples,y_train_batch.shape[0] )
                    print (f"Producing {num_samples} samples...")

                    sample_loop_omegafold_ModelA (
                        model,
                        test_loader,
                        cond_scales=cond_scales,
                        num_samples=num_samples, #how many samples produced every time tested.....
                        timesteps=None,
                        flag=e+start_ep, # steps,
                        foldproteins=foldproteins,
                        # add condi_key
                        cond_image=cond_image, # Not used for now
                        cond_text=cond_text,   # Not used for now
                        skip_steps=0,
                        #
                        max_text_len=max_text_len_X,
                        max_length=max_length_Y,
                        # ++++++++++++++++++++
                        train_unet_number=train_unet_number,
                        ynormfac=ynormfac,
                        prefix=prefix,   #
                        tokenizer_y=tokenizer_y,
                        Xnormfac_CondiText=Xnormfac,
                        tokenizer_X_CondiText=tokenizer_X,
                        # ++
                        CKeys=CKeys,
                        sample_dir=sample_dir,
                        steps=steps,
                        e=e+start_ep,
                        IF_showfig= CKeys['SlientRun']!=1 ,
                    )   

                    print ("II. SAMPLING FOR DE NOVO:")

                    sample_sequence_omegafold_ModelA (
                        # # ----------------------------------------------
                        # model,
                        # X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]],
                        # foldproteins=foldproteins,
                        # flag=steps,cond_scales=1.,
                        # ++++++++++++++++++++++++++++++++++++++++++++++
                        model,
                        X=test_condition_list, # [[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X
                        flag=e+start_ep, # steps, # 0,
                        cond_scales=cond_scales, # 1.,
                        foldproteins=True, # False,
                        X_string=None,                                # from text conditioning X_string
                        x_data=None,                                  # from image conditioning x_data   
                        skip_steps=0,
                        inpaint_images=None, # in the format of the Y data
                        inpaint_masks = None,
                        inpaint_resample_times = None,
                        init_images = None,
                        num_cycle=16,          # for omegafolding
                        calc_error=False,      # for check on folded results, not used for every case
                        # ++++++++++++++++++++++++++
                        # tokenizers
                        tokenizer_X_forImageCondi=None, # for x_data
                        Xnormfac_forImageCondi=1.,
                        tokenizer_X_forTextCondi=None,  # for X if NEEDED only
                        Xnormfac_forTextCondi=1.,
                        tokenizer_y=tokenizer_y, # None, # for output Y
                        ynormfac=ynormfac,
                        # length
                        train_unet_number=1,
                        max_length_Y=max_length_Y,                 # for Y, X_forImageCondi
                        max_text_len=max_text_len_X,                 # for    X_forTextCondi
                        # other info
                        steps=steps, # None,
                        e=e, # None,
                        sample_dir=sample_dir, # None,
                        prefix=prefix, # None,
                        IF_showfig= CKeys['SlientRun']!=1, # True,
                        CKeys=CKeys,
                        # TBA to Model B
                        normalize_X_cond_to_one=False,
                    )

                    # sample_sequence (model,
                    #     X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],foldproteins=foldproteins,
                    #      flag=steps,cond_scales=1.,
                    #    )

        # summarize the loss over each epoch:
        norm_loss_over_e = train_epoch_loss/len(train_loader)
        print("\nnorm_loss over 1 epoch: ", norm_loss_over_e)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # write this into "train_hist_file_full"
        add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
        with open(train_hist_file_full,'a') as f:
            f.write(add_line)
        
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # save the model every `save_every_epoch` epochs
        if save_model and (e+start_ep) % save_every_epoch==0 and e>1:
            # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt"
            trainer.save(fname)
            print ("Model saved: ", fname)
            # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt"
            torch.save(model.state_dict(), fname)
            print ("Statedict model saved: ", fname)
            # add a saving point file
            # use the per-epoch loss here: `norm_loss` can be unbound if the
            # printing interval was never reached during this run
            top_line='epoch,steps,norm_loss'+'\n'
            add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
            with open(save_point_info_file, "w") as f:
                f.write(top_line)
                f.write(add_line)
                
            # if steps>0:
            #     if save_model and steps % print_loss_every_steps==0: 
            #         fname=f"{prefix}trainer_save-model-epoch_{e}.pt"
            #         trainer.save(fname)
            #         fname=f"{prefix}statedict_save-model-epoch_{e}.pt"
            #         torch.save(model.state_dict(), fname)
            #         print (f"Model saved: ")

            # steps=steps+1

        print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------")
        
def train_loop_ForModelA_II (
    model,
    train_loader,
    test_loader,
    #
    optimizer=None,
    print_every=1,
    epochs= 300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss_every_steps=1000,
    #
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=[1.0], #list of cond scales
    num_samples=2, #how many samples produced every time tested.....
    foldproteins=False,
    # ++
    cond_image=False, # not use cond_images... for model A
    cond_text=True,   # use condi_text...      for model A
    # +
    device=None,
    loss_list=[],
    epoch_list=[],
    train_hist_file=None,
    train_hist_file_full=None,
    prefix=None, # not used in this function
    Xnormfac=None,
    ynormfac=1.,
    tokenizer_X=None,
    tokenizer_y=None,
    test_condition_list=[],
    max_length_Y=1,
    max_text_len_X=1,
    CKeys=None,
    sample_steps=1,
    sample_dir=None,
    save_every_epoch=1,
    save_point_info_file=None,
    store_dir=None,
    # ++ for pLM
    pLM_Model_Name=None,
    image_channels=None,
    print_error=False, # not used for Problem 6
):
    # #+
    # Xnormfac=Xnormfac.to(model.device)
    
    if not exists (trainer):
        if not exists (optimizer):
            print ("ERROR: If trainer not used, need to provide optimizer.")
    if exists (trainer):
        print ("Trainer provided... will be used")
    # --------------------------------
    # steps=start_step
    # ++++++++++++++++++++++++++++++++
    steps=start_step
    added_steps=0

    loss_total=0
    
    # ++ for pLM
    if pLM_Model_Name=='None':
        pLM_Model=None
        # keep these names defined for the sampling calls below
        esm_layer=None
        esm_alphabet=None
        
    elif pLM_Model_Name=='esm2_t33_650M_UR50D':
        # dim: 1280
        esm_layer=33
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
        
    elif pLM_Model_Name=='esm2_t36_3B_UR50D':
        # dim: 2560
        esm_layer=36
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
        
    elif pLM_Model_Name=='esm2_t30_150M_UR50D':
        # dim: 640
        esm_layer=30
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
    
    elif pLM_Model_Name=='esm2_t12_35M_UR50D':
        # dim: 480
        esm_layer=12
        pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        len_toks=len(esm_alphabet.all_toks)
        pLM_Model.eval()
        pLM_Model.to(device)
        
    else:
        # fail fast: an unknown name would otherwise surface later as a NameError
        raise ValueError(f"pLM model '{pLM_Model_Name}' is not defined.")
        
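    # A compact alternative to the if/elif chain above (a sketch only;
    # `_ESM_SPECS` is hypothetical, but fair-esm exposes its loaders under
    # exactly these names, so getattr would work):
    #   _ESM_SPECS = {'esm2_t33_650M_UR50D': 33, 'esm2_t36_3B_UR50D': 36,
    #                 'esm2_t30_150M_UR50D': 30, 'esm2_t12_35M_UR50D': 12}
    #   esm_layer = _ESM_SPECS[pLM_Model_Name]
    #   pLM_Model, esm_alphabet = getattr(esm.pretrained, pLM_Model_Name)()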
        
    for e in range(1, epochs+1):
        # start = time.time()

        torch.cuda.empty_cache()
        print ("######################################################################################")
        start = time.time()
        print ("NOW: Training epoch: ", e+start_ep)

        # TRAINING
        train_epoch_loss = 0
        model.train()

        print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")

        for item  in train_loader:
            # ++
            steps += 1
            added_steps += 1

            X_train_batch= item[0].to(device)
            y_train_batch=item[1].to(device)
            # project y_ into embedding space
            if CKeys["Debug_TrainerPack"]==1:
                print("Initial unload the dataloader items: ...")
                print(X_train_batch.shape)
                print(y_train_batch.shape)
            # ++
            # project the AA seq into embedding space
            # for output, it is shared between ModelA and ModelB
            # # --
            # if pLM_Model_Name=='None':
            #     # just use the encoded sequence
            #     y_train_batch_in = y_train_batch.unsqueeze(1)
            #     # pass
            # elif pLM_Model_Name=='esm2_t33_650M_UR50D':
            #     with torch.no_grad():
            #         results = pLM_Model(
            #             y_train_batch,
            #             repr_layers=[33],
            #             return_contacts=False,
            #         )
            #     y_train_batch_in = results["representations"][33]
            #     y_train_batch_in = rearrange(
            #         y_train_batch_in, 
            #         'b l c -> b c l'
            #     )
            # else:
            #     print(f"Required pLM name is not defined!!")
            # ++
            if pLM_Model_Name=='None':
                # just use the encoded sequence
                y_train_batch_in = y_train_batch.unsqueeze(1)
                # pass
            else: # assume an ESM-2 model
                with torch.no_grad():
                    results = pLM_Model(
                        y_train_batch,
                        repr_layers=[esm_layer],
                        return_contacts=False,
                    )
                y_train_batch_in = results["representations"][esm_layer]
                y_train_batch_in = rearrange(
                    y_train_batch_in, 
                    'b l c -> b c l'
                )
            
                
            #
            # For input part, this block is different for ModelA and ModelB
            if cond_image==False:
                # model A: X: text_condi, not affected by pLM
                X_train_batch_in = X_train_batch
            else:
                # model B: X: cond_img, will be affected by pLM
                X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1)
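                # Shape note: the 1-channel condition (batch, 1, seq_len) is
                # repeated to (batch, image_channels, seq_len) so it matches
                # the channel layout of the "image" the U-Net expects.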
            #
            # + for debug
            if CKeys["Debug_TrainerPack"]==1:
                print("After pLM model, the shape of X and y for training:")
                print("X_train_batch_in.dim: ", X_train_batch_in.shape)
                print("y_train_batch_in.dim: ", y_train_batch_in.shape)
                    
            
            

            if exists (trainer):
                if cond_image==False:
                    # ========================================
                    # Model A: condition via text
                    # ========================================
                    # this block depends on the model:forward
                    loss = trainer(
                        # # --------------------------------
                        # X_train_batch, 
                        # y_train_batch.unsqueeze(1) ,
                        # # ++++++++++++++++++++++++++++++++
                        # y_train_batch.unsqueeze(1) ,
                        # x=X_train_batch, 
                        # ++ pLM
                        y_train_batch_in,
                        x=X_train_batch_in,
                        # 
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size,    # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                    )
                if cond_image==True:
                    # ========================================
                    # Model B: condition via image/sequence
                    # ========================================
                    # # --
                    # # added for future: Train_loop B
                    # loss = trainer(
                    #     y_train_batch.unsqueeze(1) ,
                    #     x=None,
                    #     cond_images=X_train_batch.unsqueeze(1), 
                    #     unet_number=train_unet_number,
                    #     max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                    # )
                    # ++ from pLM+ModelB
                    loss = trainer(
                        y_train_batch_in,                          # true image
                        x=None,                                 # tokenized text
                        cond_images=X_train_batch_in,              # cond_image
                        unet_number=train_unet_number,
                        max_batch_size = max_batch_size, # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                        )
                    # pass
                #
                trainer.update(unet_number = train_unet_number)

            else:
                optimizer.zero_grad()
                if cond_image==False:
                    # this block depends on the model:forward
                    loss=model ( 
                        # # --------------------------------
                        # X_train_batch, 
                        # y_train_batch.unsqueeze(1) ,
                        # # ++++++++++++++++++++++++++++++++
                        # y_train_batch.unsqueeze(1) ,
                        # x=X_train_batch,
                        # ++ pLM
                        y_train_batch_in,
                        x=X_train_batch_in,
                        #
                        unet_number=train_unet_number
                    )
                if cond_image==True:
                    # added for future: Train_loop B
                    # # --
                    # loss=model (
                    #     y_train_batch.unsqueeze(1) ,
                    #     x=None, 
                    #     cond_images=X_train_batch.unsqueeze(1), 
                    #     unet_number=train_unet_number
                    # )
                    # ++ from pLM
                    loss=model (
                        y_train_batch_in ,
                        x=None, 
                        cond_images=X_train_batch_in, 
                        unet_number=train_unet_number
                    )
                #
                loss.backward( )
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

            loss_total=loss_total+loss.item()
            # +
            train_epoch_loss=train_epoch_loss+loss.item()

            if steps % print_every == 0:
                # for progress bar
                print(".", end="")

            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # record loss block
            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # if steps>0:
            if added_steps>0:

                if steps % print_loss_every_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print("Here is step: ", steps)

                    norm_loss=loss_total/print_loss_every_steps
                    print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}")
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # add a line to the hist file
                    add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n'
                    with open(train_hist_file,'a') as f:
                        f.write(add_line)

                    loss_list.append (norm_loss)
                    loss_total=0
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    epoch_list.append(e+start_ep)

                    fig = plt.figure()
                    plt.plot (epoch_list, loss_list, label='Loss')
                    plt.legend()
                    # outname = prefix+ f"loss_{e}_{steps}.jpg"
                    outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg"
                    # 
                    # the order, save then show, matters
                    if CKeys['SlientRun']==1:
                        plt.savefig(outname, dpi=200)
                    else:
                        plt.show()
                    plt.close(fig)

            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # sample in test set block
            # set sample_steps < 0 to switch off this block
            # \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
            # if steps>0:
            if added_steps>0:
                if sample_steps > 0 and steps % sample_steps == 0:
                    # + for debug
                    if CKeys['Debug_TrainerPack']==2:
                        print("Here is steps: ", steps)

                    if plot_unscaled:
                        # test before scaling...
                        plt.plot (
                            y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),
                            label= 'Unscaled GT'
                        )
                        plt.legend()
                        plt.show()

#                     # -- look like not used
#                     #rescale GT to properly plot
#                     GT=y_train_batch.cpu().detach() 

#                     GT=resize_image_to(
#                         GT.unsqueeze(1),
#                         model.imagen.image_sizes[train_unet_number-1],

#                     )
                    ####
                    print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ")
                    print ("I. SAMPLING IN TEST SET: ")
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")
                    ####

                    num_samples = min (num_samples,y_train_batch.shape[0] )
                    print (f"Producing {num_samples} samples...")

                    sample_loop_omegafold_pLM_ModelA (
                        model,
                        test_loader,
                        cond_scales=cond_scales,
                        num_samples=num_samples, #how many samples produced every time tested.....
                        timesteps=None,
                        flag=e+start_ep, # steps,
                        foldproteins=foldproteins,
                        # add condi_key
                        cond_image=cond_image, # Not used for now
                        cond_text=cond_text,   # Not used for now
                        skip_steps=0,
                        #
                        max_text_len=max_text_len_X,
                        max_length=max_length_Y,
                        # ++++++++++++++++++++
                        train_unet_number=train_unet_number,
                        ynormfac=ynormfac,
                        prefix=prefix,   #
                        tokenizer_y=tokenizer_y,
                        Xnormfac_CondiText=Xnormfac,
                        tokenizer_X_CondiText=tokenizer_X,
                        # ++
                        CKeys=CKeys,
                        sample_dir=sample_dir,
                        steps=steps,
                        e=e+start_ep,
                        IF_showfig= CKeys['SlientRun']!=1 ,
                        # ++ for pLM
                        pLM_Model=pLM_Model,
                        pLM_Model_Name=pLM_Model_Name,
                        image_channels=image_channels,
                        pLM_alphabet=esm_alphabet,
                    )   

                    print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ")
                    print ("II. SAMPLING FOR DE NOVO:")
                    print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ")

                    DeNovoSam_pdbs, fasta_file_list=\
                    sample_sequence_omegafold_pLM_ModelA (
                        # # ----------------------------------------------
                        # model,
                        # X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]],
                        # foldproteins=foldproteins,
                        # flag=steps,cond_scales=1.,
                        # ++++++++++++++++++++++++++++++++++++++++++++++
                        model,
                        X=test_condition_list, # [[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X
                        flag=e+start_ep, # steps, # 0,
                        cond_scales=cond_scales, # 1.,
                        foldproteins=True, # False,
                        X_string=None,                                # from text conditioning X_string
                        x_data=None,                                  # from image conditioning x_data   
                        skip_steps=0,
                        inpaint_images=None, # in the format of the Y data
                        inpaint_masks = None,
                        inpaint_resample_times = None,
                        init_images = None,
                        num_cycle=16,          # for omegafolding
                        calc_error=False,      # for check on folded results, not used for every case
                        # ++++++++++++++++++++++++++
                        # tokenizers
                        tokenizer_X_forImageCondi=None, # for x_data
                        Xnormfac_forImageCondi=1.,
                        tokenizer_X_forTextCondi=None,  # for X if NEEDED only
                        Xnormfac_forTextCondi=1.,
                        tokenizer_y=tokenizer_y, # None, # for output Y
                        ynormfac=ynormfac,
                        # length
                        train_unet_number=1,
                        max_length_Y=max_length_Y,                 # for Y, X_forImageCondi
                        max_text_len=max_text_len_X,                 # for    X_forTextCondi
                        # other info
                        steps=steps, # None,
                        e=e+start_ep, # fix: match flag and the test-set sampling call
                        sample_dir=sample_dir, # None,
                        prefix=prefix, # None,
                        IF_showfig= CKeys['SlientRun']!=1, # True,
                        CKeys=CKeys,
                        # TBA to Model B
                        normalize_X_cond_to_one=False,
                        # ++ for pLM
                        pLM_Model=pLM_Model,
                        pLM_Model_Name=pLM_Model_Name,
                        image_channels=image_channels,
                        pLM_alphabet=esm_alphabet,
                    )

                    # sample_sequence (model,
                    #     X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],foldproteins=foldproteins,
                    #      flag=steps,cond_scales=1.,
                    #    )

        # summarize the loss over each epoch:
        norm_loss_over_e = train_epoch_loss/len(train_loader)
        print("\nnorm_loss over 1 epoch: ", norm_loss_over_e)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # write this into "train_hist_file_full"
        add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
        with open(train_hist_file_full,'a') as f:
            f.write(add_line)
        
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # save model every this epoches
        if save_model and (e+start_ep) % save_every_epoch==0 and e>1:
            # fname=f"{prefix}trainer_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt"
            trainer.save(fname)
            print (f"Model saved: ", fname)
            # fname=f"{prefix}statedict_save-model-epoch_{e+start_ep}.pt"
            fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt"
            torch.save(model.state_dict(), fname)
            print (f"Statedict model saved: ", fname)
            # add a saving point file
            top_line='epoch,steps,norm_loss'+'\n'
            # fix: record the epoch-level loss (norm_loss was a step-window value from another scope)
            add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'
            with open(save_point_info_file, "w") as f:
                f.write(top_line)
                f.write(add_line)
                
            # if steps>0:
            #     if save_model and steps % print_loss_every_steps==0: 
            #         fname=f"{prefix}trainer_save-model-epoch_{e}.pt"
            #         trainer.save(fname)
            #         fname=f"{prefix}statedict_save-model-epoch_{e}.pt"
            #         torch.save(model.state_dict(), fname)
            #         print (f"Model saved: ")

            # steps=steps+1

        print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------")

# from the original code; no longer used
def train_loop_Model_A (
    model,
    train_loader,
    test_loader,
    optimizer=None,
    print_every=10,
    epochs= 300,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss=1000,
    trainer=None,
    plot_unscaled=False,
    max_batch_size=4,
    save_model=False,
    cond_scales=[1.0], #list of cond scales
    num_samples=2, #how many samples produced every time tested.....
    foldproteins=False,
):
    
    
    if not exists (trainer):
        if not exists (optimizer):
            # fix: fail early instead of crashing later at optimizer.zero_grad()
            raise ValueError ("If no trainer is used, an optimizer must be provided.")
    if exists (trainer):
        print ("Trainer provided... will be used")
    steps=start_step

    # fix: loss_list was used below without being initialized
    loss_list=[]
    loss_total=0
    for e in range(1, epochs+1):
            torch.cuda.empty_cache()
            print ("######################################################################################")
            start = time.time()
            print ("NOW: Training epoch: ", e+start_ep)

            # TRAINING
            train_epoch_loss = 0
            model.train()
            
            print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)")
            for item  in train_loader:
                X_train_batch= item[0].to(device)
                y_train_batch=item[1].to(device)

                if exists (trainer):
                    loss = trainer(
                            X_train_batch, y_train_batch.unsqueeze(1) ,
                            unet_number=train_unet_number,
                            max_batch_size = max_batch_size,    # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
                        )
                    trainer.update(unet_number = train_unet_number)

                else:
                    optimizer.zero_grad()
                    loss=model ( X_train_batch, y_train_batch.unsqueeze(1) ,unet_number=train_unet_number)
                    loss.backward( )
                   
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

                    optimizer.step()

                loss_total=loss_total+loss.item()
                
                if steps % print_every == 0:
                    print(".", end="")

                if steps>0:
                    if steps % print_loss == 0:

                        if plot_unscaled:
                            
                            plt.plot (y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),label= 'Unscaled GT')
                            plt.legend()
                            plt.show()
 
                        #rescale GT to properly plot
                        GT=y_train_batch.cpu().detach() 
                        
                        GT=resize_image_to(
                            GT.unsqueeze(1),
                            model.imagen.image_sizes[train_unet_number-1],

                        )
                        
                        norm_loss=loss_total/print_loss
                        print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}")

                        loss_list.append (norm_loss)
                        loss_total=0

                        plt.plot (loss_list, label='Loss')
                        plt.legend()

                        # note: `prefix` is assumed to be defined at module level (legacy behavior)
                        outname = prefix+ f"loss_{e}_{steps}.jpg"
                        plt.savefig(outname, dpi=200)
                        plt.show()
                        
                        num_samples = min (num_samples,y_train_batch.shape[0] )
                        print (f"Producing {num_samples} samples...")
                        
                        sample_loop (
                            model,
                            test_loader,
                            cond_scales=cond_scales,
                            num_samples=1, # how many samples produced every time tested
                            timesteps=64,
                            flag=steps,
                            foldproteins=foldproteins,
                        )
                        
                        print ("SAMPLING FOR DE NOVO:")
                        sample_sequence (
                            model,
                            X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]],
                            foldproteins=foldproteins,
                            flag=steps, cond_scales=1.,
                        )
                        sample_sequence (
                            model,
                            X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],
                            foldproteins=foldproteins,
                            flag=steps, cond_scales=1.,
                        )

                if steps>0:
                    if save_model and steps % print_loss==0: 
                        fname=f"{prefix}trainer_save-model-epoch_{e}.pt"
                        trainer.save(fname)
                        fname=f"{prefix}statedict_save-model-epoch_{e}.pt"
                        torch.save(model.state_dict(), fname)
                        print (f"Model saved: ")
                    
                steps=steps+1
                                         
            print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------")

# +++
def sample_sequence_omegafold_ModelA (
    model,
    X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X
    flag=0,
    cond_scales=1.,
    foldproteins=False,
    X_string=None,                                # from text conditioning X_string
    x_data=None,                                  # from image conditioning x_data   
    skip_steps=0,
    inpaint_images=None, # in the format of the Y data
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,          # for omegafolding
    calc_error=False,      # for check on folded results, not used for every case
    # ++++++++++++++++++++++++++
    # tokenizers
    tokenizer_X_forImageCondi=None, # for x_data
    Xnormfac_forImageCondi=1.,
    tokenizer_X_forTextCondi=None,  # for X if NEEDED only
    Xnormfac_forTextCondi=1.,
    tokenizer_y=None,               # for output Y
    ynormfac=1,
    # length
    train_unet_number=1,
    max_length_Y=1,                 # for Y, X_forImageCondi
    max_text_len=1,                 # for    X_forTextCondi
    # other info
    steps=None,
    e=None,
    sample_dir=None,
    prefix=None,
    IF_showfig=True,
    CKeys=None,
    # TBA to Model B
    normalize_X_cond_to_one=False,
):
    # -----------
    # steps=0
    # e=flag

    # --
    # print (f"Producing {len(X)} samples...")
    # ++
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_string...")
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditioning x_data...")
        lenn_val=len(x_data)
        # print (x_data)
    
    print ('Device: ', model.device)
    
    # fix: collect results across samples; the list was previously (re)created
    # inside the loop and appended to only after it
    pdb_list=[]
    for iisample in range (lenn_val):
        print(f"Working on {iisample}")
        X_cond=None
        
        if X_string is None and X is not None: # for X channel
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None: # from raw text, i.e., X_string: needs tokenizer_X and Xnormfac
            # -
            # X = tokenizer_X.texts_to_sequences(X_string[iisample])
            # X= sequence.pad_sequences(X,  maxlen=max_length, padding='post', truncating='post')  
            # X=np.array(X)
            # X_cond=torch.from_numpy(X).float()/Xnormfac
            # +
            XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample])
            # note: `sequence` (pad_sequences) is assumed to be imported at module
            # level, e.g. from keras.preprocessing
            XX = sequence.pad_sequences(XX,  maxlen=max_text_len, padding='post', truncating='post')  
            XX = np.array(XX)
            X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi
            
            print ('Tokenized and processed: ', X_cond)
        
        if X_cond is not None:
            if normalize_X_cond_to_one: # used when there is a constraint on X_cond.sum()
                X_cond=X_cond/X_cond.sum()
        
            print ("Text conditioning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales)
        else:
            print ("Text conditioning used: None")
        
        # for now, assume image_condi and text_condi can be used at the same time
        if tokenizer_X_forImageCondi is None:
            # ===========================================================
            # condi_image/seq needs no tokenization, like numbers: force_path
            # only normalization needed
            # Based on ModelB:Force_Path
            if x_data is not None:
                x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi)
                x_data_tokenized=x_data_tokenized.to(torch.float)
                # + for debug:
                if CKeys['Debug_TrainerPack']==1:
                    print("x_data_tokenized dim: ", x_data_tokenized.shape)
                    print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                    print("test: ", x_data_tokenized is not None)
            else:
                x_data_tokenized=None
                # + for debug:
                if CKeys['Debug_TrainerPack']==1:
                    print("x_data_tokenized and x_data: None")
            
            # model.sample full arguments, for reference:
            #   self,
            #   x=None,
            #   stop_at_unet_number=1,
            #   cond_scale=7.5,
            #   # ++
            #   x_data=None, # image_condi data
            #   skip_steps=None,
            #   inpaint_images = None,
            #   inpaint_masks = None,
            #   inpaint_resample_times = 5,
            #   init_images = None,
            #   x_data_tokenized=None,
            #   tokenizer_X=None,
            #   Xnormfac=1.,
            #   # -+
            #   device=None,
            #   max_length=1., # for X-and-Y data, in image/sequence format; NOT for text condition
            #   max_text_len=1., # for X data, in text format
            #
            result=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=None, 
                # ++
                x_data_tokenized=x_data_tokenized,
                #
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,
                device=model.device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X_forImageCondi, # tokenizer_X,
                Xnormfac=Xnormfac_forImageCondi, # Xnormfac,
                # ynormfac=ynormfac, 
                max_length=max_length_Y, # for ImageCondi, max_length,
                max_text_len=max_text_len,
            )
        else:
            #
            result=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=x_data[iisample], 
                # ++
                x_data_tokenized=None,
                #
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,
                device=model.device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X_forImageCondi, # tokenizer_X,
                Xnormfac=Xnormfac_forImageCondi, # Xnormfac,
                # ynormfac=ynormfac, 
                max_length=max_length_Y, # max_length,
                max_text_len=max_text_len,
            )
            
        # # ------------------------------------------    
        # result=model.sample ( 
        #     X_cond,
        #     stop_at_unet_number=train_unet_number,
        #     cond_scale=cond_scales 
        # )
            
        result=torch.round(result*ynormfac)
        # + for debug
        print("result.dim: ", result.shape)
        
        fig=plt.figure()
        plt.plot (
            result[0,0,:].cpu().detach().numpy(),
            label= f'Predicted'
        )
        #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
        plt.legend()
        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
        if IF_showfig==1:
            plt.show ()
        else:
            plt.savefig(outname, dpi=200)
        plt.close()
        
        # # ----------------------------------------
        # plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted')
        # plt.legend()
        # outname = prefix+ f"sampld_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        # plt.savefig(outname, dpi=200)
        # plt.show ()

        to_rev=result[:,0,:]
        to_rev=to_rev.long().cpu().detach().numpy()
        print("to_rev.dim: ", to_rev.shape)
        y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

        for iii in range (len(y_data_reversed)):
            y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
        
        # + from Model B
        ### reverse secondary structure input....
        if X_cond is not None: 
            # there is condi_text
            if X_string is not None:
                X_cond=torch.round(X_cond*Xnormfac_forTextCondi)

                to_rev=X_cond[:,:] 
                to_rev=to_rev.long().cpu().detach().numpy()
                print ("to_rev.dim: ", to_rev.shape)
                # --
                # X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
                # ++
                X_text_reversed=tokenizer_X_forTextCondi.sequences_to_texts (to_rev)
                # fix: the loop bound used y_text_reversed, an undefined name
                for iii in range (len(X_text_reversed)):
                    X_text_reversed[iii]=X_text_reversed[iii].upper().strip().replace(" ", "")
                    
            if X_string is None:
                # reverse this: X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
                X_text_reversed=X_cond
        else:
            X_text_reversed=None
                
        if x_data is not None: # there is condi_image
            x_data_reversed=x_data # already in sequence format...
        else:
            x_data_reversed=None
        
        # summary
        # print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,\n predicted sequence: ", y_data_reversed)
        print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,")
        print (f"predicted sequence full: {y_data_reversed}")
        # added just in case, for checking
        print (f"predicted sequence:      {y_data_reversed[0]}")
        
        # + for debug
        print("================================================")
        print("foldproteins: ", foldproteins)
        
        if not foldproteins:
            pdb_file=None
        else:
        # if foldproteins:
            
            # (no special handling needed for X_cond here)
            
            tempname='temp'
            pdb_file=foldandsavePDB (
                sequence=y_data_reversed[0], 
                filename_out=tempname, 
                num_cycle=num_cycle, 
                flag=flag,
                # +++++++++++++++++++
                # prefix=prefix,
                prefix=sample_dir,
            )
            #
            out_nam=iisample
            #
            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
            write_fasta (y_data_reversed[0], out_nam_fasta) 
            #
            out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
            shutil.copy (pdb_file, out_nam) #source, dest
            pdb_file=out_nam
            #
            print (f"Properly named PDB file produced: {pdb_file}")
            if IF_showfig==1:
                #flag=1000
                view=show_pdb(
                    pdb_file=pdb_file, 
                    flag=flag,
                    show_sidechains=show_sidechains, 
                    show_mainchains=show_mainchains, 
                    color=color
                )
                view.show()
                
            if calc_error:
                # only works for ModelA:SecStr
                if CKeys['Problem_ID']==7:
                    get_Model_A_error (pdb_file, X[iisample], plotit=True)
                else:
                    print ("Error calculation on the predicted results is not applicable")
                
        # fix: append inside the sampling loop so one entry per sample is returned
        pdb_list.append(pdb_file)
        
    return pdb_list
                
            
#             xbc=X_cond[iisample,:].cpu().detach().numpy()
#             out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.2f" % xbc})+f'_{flag}_{steps}'
#             tempname='temp'
#             pdb_file=foldandsavePDB (sequence=y_data_reversed[0], 
#                                                  filename_out=tempname, 
#                                                  num_cycle=16, flag=flag)
#             out_nam_fasta=f'{prefix}{out_nam}.fasta'
            
#             out_nam=f'{prefix}{out_nam}.pdb'
            
#             write_fasta (y_data_reversed[0], out_nam_fasta)
            
#             shutil.copy (pdb_file, out_nam) #source, dest
#             pdb_file=out_nam
#             print (f"Properly named PDB file produced: {pdb_file}")
          
#             view=show_pdb(pdb_file=pdb_file, flag=flag,
#                           show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color)
#             view.show()

#         if calc_error:
#             get_Model_A_error (pdb_file, X[iisample], plotit=True)
#         return pdb_file
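
# ++ hedged example: calling the non-pLM de novo sampler
# A minimal usage sketch for sample_sequence_omegafold_ModelA, assuming that
# tokenizer_y, ynormfac, max_length_Y, max_text_len_X, CKeys and sample_dir
# follow the conventions used in the training loop above; the conditioning
# vector is illustrative only.
# pdb_list = sample_sequence_omegafold_ModelA(
#     model,
#     X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0.]],  # one secondary-structure mix
#     flag=0, cond_scales=1., foldproteins=True,
#     tokenizer_y=tokenizer_y, ynormfac=ynormfac,
#     max_length_Y=max_length_Y, max_text_len=max_text_len_X,
#     steps=0, e=0, sample_dir=sample_dir, prefix=prefix,
#     IF_showfig=False, CKeys=CKeys,
# )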

# +++
def sample_sequence_omegafold_pLM_ModelA (
    model,
    X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], # from text conditioning X
    flag=0,
    cond_scales=1.,
    foldproteins=False,
    X_string=None,                                # from text conditioning X_string
    x_data=None,                                  # from image conditioning x_data   
    skip_steps=0,
    inpaint_images=None, # in the format of the Y data
    inpaint_masks = None,
    inpaint_resample_times = None,
    init_images = None,
    num_cycle=16,          # for omegafolding
    calc_error=False,      # for check on folded results, not used for every case
    # ++++++++++++++++++++++++++
    # tokenizers
    tokenizer_X_forImageCondi=None, # for x_data
    Xnormfac_forImageCondi=1.,
    tokenizer_X_forTextCondi=None,  # for X if NEEDED only
    Xnormfac_forTextCondi=1.,
    tokenizer_y=None,               # for output Y
    ynormfac=1,
    # length
    train_unet_number=1,
    max_length_Y=1,                 # for Y, X_forImageCondi
    max_text_len=1,                 # for    X_forTextCondi
    # other info
    steps=None,
    e=None,
    sample_dir=None,
    prefix=None,
    IF_showfig=True,
    CKeys=None,
    # TBA to Model B
    normalize_X_cond_to_one=False,
    # ++
    pLM_Model=None, # pLM_Model,
    pLM_Model_Name=None, # pLM_Model_Name,
    image_channels=None, # image_channels,
    pLM_alphabet=None, # esm_alphabet,
):
    # -----------
    # steps=0
    # e=flag

    # --
    # print (f"Producing {len(X)} samples...")
    # ++
    if X is not None:
        print (f"Producing {len(X)} samples...from text conditioning X...")
        lenn_val=len(X)
    if X_string is not None:
        lenn_val=len(X_string)
        print (f"Producing {len(X_string)} samples...from text conditioning X_string...")
    if x_data is not None:
        print (f"Producing {len(x_data)} samples...from image conditioning x_data...")
        lenn_val=len(x_data)
        # print (x_data)
    
    print ('Device: ', model.device)
    
    pdb_list=[]
    fasta_list=[]
    
    for iisample in range (lenn_val):
        print(f"Working on {iisample}")
        X_cond=None
        
        if X_string is None and X is not None: # for X channel
            X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
        if X_string is not None: # from raw text, i.e., X_string: needs tokenizer_X and Xnormfac
            # -
            # X = tokenizer_X.texts_to_sequences(X_string[iisample])
            # X= sequence.pad_sequences(X,  maxlen=max_length, padding='post', truncating='post')  
            # X=np.array(X)
            # X_cond=torch.from_numpy(X).float()/Xnormfac
            # +
            XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample])
            # note: `sequence` (pad_sequences) is assumed to be imported at module
            # level, e.g. from keras.preprocessing
            XX = sequence.pad_sequences(XX,  maxlen=max_text_len, padding='post', truncating='post')  
            XX = np.array(XX)
            X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi
            
            print ('Tokenized and processed: ', X_cond)
        
        if X_cond is not None:
            if normalize_X_cond_to_one: # used when there is a constraint on X_cond.sum()
                X_cond=X_cond/X_cond.sum()
        
            print ("Text conditioning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales)
        else:
            print ("Text conditioning used: None")
        
        # for now, assume image_condi and text_condi can be used at the same time
        if tokenizer_X_forImageCondi is None:
            # ===========================================================
            # condi_image/seq needs no tokenization, like numbers: force_path
            # only normalization needed
            # Based on ModelB:Force_Path
            if x_data is not None:
                x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi)
                x_data_tokenized=x_data_tokenized.to(torch.float)
                # + for debug:
                if CKeys['Debug_TrainerPack']==1:
                    print("x_data_tokenized dim: ", x_data_tokenized.shape)
                    print("x_data_tokenized dtype: ", x_data_tokenized.dtype)
                    print("test: ", x_data_tokenized is not None)
            else:
                x_data_tokenized=None
                # + for debug:
                if CKeys['Debug_TrainerPack']==1:
                    print("x_data_tokenized and x_data: None")
            
            # model.sample full arguments, for reference:
            #   self,
            #   x=None,
            #   stop_at_unet_number=1,
            #   cond_scale=7.5,
            #   # ++
            #   x_data=None, # image_condi data
            #   skip_steps=None,
            #   inpaint_images = None,
            #   inpaint_masks = None,
            #   inpaint_resample_times = 5,
            #   init_images = None,
            #   x_data_tokenized=None,
            #   tokenizer_X=None,
            #   Xnormfac=1.,
            #   # -+
            #   device=None,
            #   max_length=1., # for X-and-Y data, in image/sequence format; NOT for text condition
            #   max_text_len=1., # for X data, in text format
            #
            result_embedding=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=None, 
                # ++
                x_data_tokenized=x_data_tokenized,
                #
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,
                device=model.device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X_forImageCondi, # tokenizer_X,
                Xnormfac=Xnormfac_forImageCondi, # Xnormfac,
                # ynormfac=ynormfac, 
                max_length=max_length_Y, # for ImageCondi, max_length,
                max_text_len=max_text_len,
            )
        else:
            # this is for model B in the future
            # two channels should be provided: raw cond_img+img_tokenizer or tokenized_cond_img
            # needs to be updated and merged with code from
            # fun.sample_sequence_omegafold_pLM_ModelB
            # one branch is currently missing
            result_embedding=model.sample ( 
                x=X_cond,
                stop_at_unet_number=train_unet_number ,
                cond_scale=cond_scales ,
                x_data=x_data[iisample],  
                # ++
                x_data_tokenized=None,
                #
                skip_steps=skip_steps,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = init_images,
                device=model.device,
                # ++++++++++++++++++++++++++
                tokenizer_X=tokenizer_X_forImageCondi, # tokenizer_X,
                Xnormfac=Xnormfac_forImageCondi, # Xnormfac,
                # ynormfac=ynormfac, 
                max_length=max_length_Y, # max_length,
                max_text_len=max_text_len,
            )
            
        # # ------------------------------------------    
        # result=model.sample ( 
        #     X_cond,
        #     stop_at_unet_number=train_unet_number,
        #     cond_scale=cond_scales 
        # )
       
    
        # # -----------------------------------------------
        # result=torch.round(result*ynormfac)
        # +++++++++++++++++++++++++++++++++++++++++++++++
        # ++ for pLM
        # full record
        # result_embedding as image.dim: [batch, channels, seq_len]
        # result_tokens.dim: [batch, seq_len]
        result_tokens,result_logits = convert_into_tokens(
            pLM_Model, 
            result_embedding,
            pLM_Model_Name,
        )
        result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]
        
        # + for debug
        print("result.dim: ", result.shape)
        
        fig=plt.figure()
        plt.plot (
            result[0,0,:].cpu().detach().numpy(),
            label= f'Predicted'
        )
        #plt.plot (GT[samples,0,:]*ynormfac,label= f'GT {0}')
        plt.legend()
        outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        #plt.title (f"Sample {samples}, cond scale={str (cond_scales[iisample])}")
        if IF_showfig==1:
            plt.show ()
        else:
            plt.savefig(outname, dpi=200)
        plt.close()
        
        # # ----------------------------------------
        # plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted')
        # plt.legend()
        # outname = prefix+ f"sampld_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg"
        # plt.savefig(outname, dpi=200)
        # plt.show ()
        
#         # ---------------------------------------------------------
#         to_rev=result[:,0,:]
#         to_rev=to_rev.long().cpu().detach().numpy()
#         print("to_rev.dim: ", to_rev.shape)
#         y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

#         for iii in range (len(y_data_reversed)):
#             y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        to_rev=result[:,0,:]
        # the following fun decides ending automatically
        y_data_reversed=decode_many_ems_token_rec_for_folding(
            to_rev,
            result_logits,
            pLM_alphabet,
            pLM_Model,
        )
        if CKeys['Debug_TrainerPack']==3:
            print("on y_data_reversed[0]: ", y_data_reversed[0])
        
        
        # + from Model B
        ### reverse secondary structure input....
        
        #
        if X_cond is not None: 
            # there is condi_text
            if X_string is not None:
                X_cond=torch.round(X_cond*Xnormfac_forTextCondi)

                to_rev=X_cond[:,:] 
                to_rev=to_rev.long().cpu().detach().numpy()
                print ("to_rev.dim: ", to_rev.shape)
                # --
                # X_data_reversed=tokenizer_X.sequences_to_texts (to_rev)
                # ++
                X_text_reversed=tokenizer_X_forTextCondi.sequences_to_texts (to_rev)
                # fix: the loop bound used y_text_reversed, an undefined name
                for iii in range (len(X_text_reversed)):
                    X_text_reversed[iii]=X_text_reversed[iii].upper().strip().replace(" ", "")
                    
            if X_string is None:
                # reverse this: X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0)
                X_text_reversed=X_cond
        else:
            X_text_reversed=None
                
        if x_data is not None: # there is condi_image
            x_data_reversed=x_data # already in sequence format...
        else:
            x_data_reversed=None
        
        # summary
        print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,\n predicted sequence: ", y_data_reversed)
        
        # + for debug
        print("================================================")
        print("foldproteins: ", foldproteins)
        
        if not foldproteins:
            pdb_file=None
        else:
        # if foldproteins:
            
            # (no special handling needed for X_cond here)
            
            tempname='temp'
            pdb_file, fasta_file=foldandsavePDB_pdb_fasta (
                sequence=y_data_reversed[0], 
                filename_out=tempname, 
                num_cycle=num_cycle, 
                flag=flag,
                # +++++++++++++++++++
                # prefix=prefix,
                prefix=sample_dir,
            )
            #
            out_nam=iisample
            #
            # out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
            # write_fasta (y_data_reversed[0], out_nam_fasta) 
            #
            out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb'
            out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta'
            shutil.copy (pdb_file, out_nam) #source, dest
            shutil.copy (fasta_file, out_nam_fasta)
            # clean the slate to avoid mistakenly reusing the previous fasta file
            os.remove (pdb_file)
            os.remove (fasta_file)
            #
            pdb_file=out_nam
            fasta_file=out_nam_fasta
            #
            pdb_list.append(pdb_file)
            fasta_list.append(fasta_file)
            #
            print (f"Properly named PDB file produced: {pdb_file}")
            if IF_showfig==1:
                #flag=1000
                view=show_pdb(
                    pdb_file=pdb_file, 
                    flag=flag,
                    show_sidechains=show_sidechains, 
                    show_mainchains=show_mainchains, 
                    color=color
                )
                view.show()
                
            if calc_error:
                if CKeys['Problem_ID']==7:
                    # only works for ModelA:SecStr
                    get_Model_A_error (pdb_file, X[iisample], plotit=True)
                else:
                    print("Error calculation on the predicted results is not applicable...")
                
        
    
    return pdb_list, fasta_list
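
# ++ hedged example: the pLM decode path in isolation
# A minimal sketch of how a sampled embedding becomes a sequence, using the same
# package helpers called above. It assumes `result_embedding` with shape
# [batch, channels, seq_len] comes from model.sample(...), and that pLM_Model,
# pLM_Model_Name and pLM_alphabet come from load_in_pLM(...).
# result_tokens, result_logits = convert_into_tokens(
#     pLM_Model, result_embedding, pLM_Model_Name,
# )
# # the ending of each sequence is decided automatically from the tokens/logits
# seq_list = decode_many_ems_token_rec_for_folding(
#     result_tokens, result_logits, pLM_alphabet, pLM_Model,
# )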

    
# + TBU
# ++++++++++++++++++++++++++++++++++++++++++++++++
def sample_loop_omegafold_ModelA (
    model,
    train_loader,
    cond_scales=None, # [7.5], #list of cond scales - each sampled...
    num_samples=None, # 2, #how many samples produced every time tested.....
    timesteps=None, # 100, # not used
    flag=None, # 0,
    foldproteins=False,
    #
    cond_image=False, # use_text_embedd=True,
    cond_text=True, 
    skip_steps=0,
    #
    max_text_len=None,
    max_length=None,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=None,
    prefix=None,
    tokenizer_y=None,
    Xnormfac_CondiText=1,
    tokenizer_X_CondiText=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # show figures/PDB views interactively; otherwise save to disk
):
    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    # steps=0
    # e=flag
    # for item  in train_loader:
    for idx, item  in enumerate(train_loader):

        X_train_batch= item[0].to(device)
        y_train_batch=item[1].to(device)

        GT=y_train_batch.cpu().detach() 

        GT= GT.unsqueeze(1)
        if num_samples>y_train_batch.shape[0]:
            print("Warning: sampling # > len(mini_batch)")

        num_samples = min (num_samples,y_train_batch.shape[0] )
        print (f"Producing {num_samples} samples...")
        X_train_batch_picked = X_train_batch[:num_samples,:]
        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)

        # loop over cond_scales:list
        for iisample in range (len (cond_scales)):

            # ++ for model A
            result=model.sample (
                x=X_train_batch_picked,
                stop_at_unet_number=train_unet_number,
                cond_scale=cond_scales[iisample],
                #
                skip_steps=skip_steps,
                device=model.device,
                #
                max_length=max_length,
                max_text_len=max_text_len,
            )
            # # ++ for model B
            # if use_text_embedd:
            #     result=model.sample (
            #         # x= X_train_batch,
            #         x= X_train_batch_picked,
            #         stop_at_unet_number=train_unet_number ,
            #         cond_scale=cond_scales[iisample], 
            #         device=device, 
            #         skip_steps=skip_steps
            #     )
            # else:
            #     result=model.sample (
            #         x= None, 
            #         # x_data_tokenized= X_train_batch,
            #         x_data_tokenized= X_train_batch_picked,
            #         stop_at_unet_number=train_unet_number ,
            #         cond_scale=cond_scales[iisample],
            #         device=device,
            #         skip_steps=skip_steps
            #     )
        
            result=torch.round(result*ynormfac)
            GT=torch.round (GT*ynormfac)

            for samples in range  (num_samples):
                print ("sample ", samples+1, "out of ", num_samples)

                fig=plt.figure()
                plt.plot (
                    result[samples,0,:].cpu().detach().numpy(),
                    label= f'Predicted'
                )
                plt.plot (
                    GT[samples,0,:],
                    label= f'GT {0}'
                )
                plt.legend()
                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                if IF_showfig==1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                plt.close ()

                #reverse y sequence
                to_rev=result[:,0,:]
                to_rev=to_rev.long().cpu().detach().numpy()

                y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                for iii in range (len(y_data_reversed)):
                    y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")

                #reverse GT_y sequence
                to_rev=GT[:,0,:]
                to_rev=to_rev.long().cpu().detach().numpy()

                GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

                # fix: iterate over GT_y_data_reversed (was y_data_reversed)
                for iii in range (len(GT_y_data_reversed)):
                    GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")

                ### reverse second structure input....
                # pay attension to the shape of Xnormfac
                # -
                # to_rev=torch.round (X_train_batch[:,:]*Xnormfac_CondiText)
                # +
                to_rev=torch.round (X_train_batch[:,:]*torch.FloatTensor(Xnormfac_CondiText).to(model.device))
                to_rev=to_rev.long().cpu().detach().numpy()
                
                # ++ different input
                if CKeys['Debug_TrainerPack']==1:
                    print("tokenizer_X_CondiText: ", tokenizer_X_CondiText)
                    print("Xnormfac_CondiText: ", Xnormfac_CondiText)
                    
                if tokenizer_X_CondiText is not None:
                    X_data_reversed=tokenizer_X_CondiText.sequences_to_texts (to_rev)
                    # fix: iterate over X_data_reversed (was y_data_reversed)
                    for iii in range (len(X_data_reversed)):
                        X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
                else:
                    X_data_reversed=to_rev.copy()
                # + for debug
                if CKeys['Debug_TrainerPack']==1:
                    print("X_data_reversed: ", X_data_reversed)
                

                print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples])
                print (f"Ground truth: {GT_y_data_reversed[samples]}")

                if foldproteins:
                    xbc=X_train_batch[samples,:].cpu().detach().numpy()
                    out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                    tempname='temp'
                    pdb_file=foldandsavePDB (
                        sequence=y_data_reversed[samples], 
                        filename_out=tempname, 
                        num_cycle=16, flag=flag,
                        # +++++++++++++++++++
                        prefix=prefix
                    )

                    # #out_nam=f'{prefix}{out_nam}.pdb'
                    # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
                    # ------------------------------------------------------
                    # sometime, this name below can get too long to fit
                    # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
                    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    # add a way to save the sampling name and results
                    # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                    out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
                    out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt'
                    
                    if CKeys['Debug_TrainerPack']==1:
                        print("pdb_file: ", pdb_file)
                        print("out_nam: ", out_nam)
                        
                    print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                    shutil.copy (pdb_file, out_nam) #source, dest
                    # +
                    with open(out_nam_inX, "w") as inX_file:
                        inX_file.write(f'{X_data_reversed[samples]}\n')
                        
                    pdb_file=out_nam
                    print (f"Properly named PDB file produced: {pdb_file}")
                    print (f"input X for sampling stored: {pdb_file}")
                    
                    if IF_showfig==1:
                        view=show_pdb(
                            pdb_file=pdb_file, 
                            flag=flag, 
                            show_sidechains=show_sidechains,  
                            show_mainchains=show_mainchains, 
                            color=color
                        )
                        view.show()

#                 steps=steps+1
                
#         if steps>num_samples:
#             break
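
# ++ hedged example: running the non-pLM test-set sampling loop
# A minimal usage sketch for sample_loop_omegafold_ModelA; the arguments mirror
# the pLM call made in the training loop above, and the specific values
# (cond_scales, num_samples) are illustrative assumptions.
# sample_loop_omegafold_ModelA(
#     model,
#     test_loader,
#     cond_scales=[1.0],
#     num_samples=2,
#     flag=0, foldproteins=True,
#     max_text_len=max_text_len_X, max_length=max_length_Y,
#     train_unet_number=1, ynormfac=ynormfac, prefix=prefix,
#     tokenizer_y=tokenizer_y,
#     Xnormfac_CondiText=Xnormfac, tokenizer_X_CondiText=tokenizer_X,
#     CKeys=CKeys, sample_dir=sample_dir, steps=0, e=0,
#     IF_showfig=False,
# )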

# + TBU
# ++++++++++++++++++++++++++++++++++++++++++++++++
def sample_loop_omegafold_pLM_ModelA (
    model,
    train_loader,
    cond_scales=None, # [7.5], #list of cond scales - each sampled...
    num_samples=None, # 2, #how many samples produced every time tested.....
    timesteps=None, # 100, # not used
    flag=None, # 0,
    foldproteins=False,
    #
    cond_image=False, # use_text_embedd=True,
    cond_text=True, 
    skip_steps=0,
    #
    max_text_len=None,
    max_length=None,
    # +++++++++++++++++++
    train_unet_number=1,
    ynormfac=None,
    prefix=None,
    tokenizer_y=None,
    Xnormfac_CondiText=1,
    tokenizer_X_CondiText=None,
    # ++
    CKeys=None,
    sample_dir=None,
    steps=None,
    e=None,
    IF_showfig=True, # show figures/PDB views interactively; otherwise save to disk
    # ++ for pLM
    pLM_Model=None,
    pLM_Model_Name=None,
    image_channels=None,
    pLM_alphabet=None,
    # ++ for on-fly check: for SecStr only
    calc_error=False,      # for check on folded results, not used for every case
):
    # =====================================================
    # sample # = num_samples*(# of mini-batches)
    # =====================================================
    # steps=0
    # e=flag
    # for item  in train_loader:
    for idx, item  in enumerate(train_loader):

        X_train_batch= item[0].to(device)
        y_train_batch=item[1].to(device)

        GT=y_train_batch.cpu().detach() 

        GT= GT.unsqueeze(1)
        if num_samples>y_train_batch.shape[0]:
            print("Warning: sampling # > len(mini_batch)")

        num_samples = min (num_samples,y_train_batch.shape[0] )
        print (f"Producing {num_samples} samples...")
        X_train_batch_picked = X_train_batch[:num_samples,:]
        print ('(TEST) X_batch shape: ', X_train_batch_picked.shape)

        # loop over cond_scales:list
        for iisample in range (len (cond_scales)):

            # ++ for model A
            result_embedding = model.sample (
                x=X_train_batch_picked,
                stop_at_unet_number=train_unet_number,
                cond_scale=cond_scales[iisample],
                #
                skip_steps=skip_steps,
                device=model.device,
                #
                max_length=max_length,
                max_text_len=max_text_len,
                #
                x_data=None,
                x_data_tokenized=None,
                #
                tokenizer_X=tokenizer_X_CondiText,
                Xnormfac=Xnormfac_CondiText,
            )
            # # ++ for model B
            # if use_text_embedd:
            #     result=model.sample (
            #         # x= X_train_batch,
            #         x= X_train_batch_picked,
            #         stop_at_unet_number=train_unet_number ,
            #         cond_scale=cond_scales[iisample], 
            #         device=device, 
            #         skip_steps=skip_steps
            #     )
            # else:
            #     result=model.sample (
            #         x= None, 
            #         # x_data_tokenized= X_train_batch,
            #         x_data_tokenized= X_train_batch_picked,
            #         stop_at_unet_number=train_unet_number ,
            #         cond_scale=cond_scales[iisample],
            #         device=device,
            #         skip_steps=skip_steps
            #     )
            
            # ++ for pLM:
            # full record
            # result_embedding as image.dim: [batch, channels, seq_len]
            # result_tokens.dim: [batch, seq_len]
            result_tokens,result_logits = convert_into_tokens(
                pLM_Model, 
                result_embedding,
                pLM_Model_Name,
            )
        
            # # --------------------------------------------
            # result=torch.round(result*ynormfac)
            # GT=torch.round (GT*ynormfac)
            # ++++++++++++++++++++++++++++++++++++++++++++
            result=result_tokens.unsqueeze(1) # dim: [batch, 1, seq_len]
            
#                 # ---------------------------------------------------------
#                 #reverse y sequence
#                 to_rev=result[:,0,:]
#                 to_rev=to_rev.long().cpu().detach().numpy()

#                 y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

#                 for iii in range (len(y_data_reversed)):
#                     y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            to_rev=result[:,0,:] # token (batch,seq_len)
            y_data_reversed=decode_many_ems_token_rec_for_folding(
                to_rev,
                result_logits,
                pLM_alphabet,
                pLM_Model,
            )
            if CKeys['Debug_TrainerPack']==3:
                print("on y_data_reversed[0]: ", y_data_reversed[0])
                

#                 # -----------------------------------------------------------
#                 #reverse GT_y sequence
#                 to_rev=GT[:,0,:]
#                 to_rev=to_rev.long().cpu().detach().numpy()

#                 GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev)

#                 for iii in range (len(y_data_reversed)):
#                     GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "")
            # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            #reverse GT_y sequence
            # GT should be SAFE to reverse
            to_rev=GT[:,0,:]
            GT_y_data_reversed=decode_many_ems_token_rec(
                to_rev,
                pLM_alphabet,
            )

            ### reverse secondary structure input....
            # pay attention to the shape of Xnormfac
            # -
            # to_rev=torch.round (X_train_batch[:,:]*Xnormfac_CondiText)
            # +
            # print("X_train_batch", X_train_batch)
            # print("Xnormfac_CondiText: ", Xnormfac_CondiText)

            # to_rev=torch.round (X_train_batch[:,:]*torch.FloatTensor(Xnormfac_CondiText).to(model.device))
            # print("X_train_batch: ", X_train_batch[:,:])
            # print("torch.tensor(Xnormfac_CondiText): ", torch.tensor(Xnormfac_CondiText))
            to_rev=X_train_batch[:,:]*torch.tensor(Xnormfac_CondiText).to(model.device)
            # print("to_rev ", to_rev)
            # # -: convert into int64
            # to_rev=to_rev.long().cpu().detach().numpy()
            # +: just float
            to_rev=to_rev.cpu().detach().numpy()
            # print("to_rev 2", to_rev)

            # ++ different input
            if CKeys['Debug_TrainerPack']==1:
                print("tokenizer_X_CondiText: ", tokenizer_X_CondiText)
                print("Xnormfac_CondiText: ", Xnormfac_CondiText)

            if tokenizer_X_CondiText is not None:
                # round the numbers into tokens
                to_rev = np.round(to_rev)
                X_data_reversed=tokenizer_X_CondiText.sequences_to_texts (to_rev)
                # fix: iterate over X_data_reversed (was y_data_reversed)
                for iii in range (len(X_data_reversed)):
                    X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "")
            else:
                X_data_reversed=to_rev.copy()
            # + for debug
            if CKeys['Debug_TrainerPack']==1:
                print("X_data_reversed: ", X_data_reversed)
                # only an array has .shape; the tokenizer path returns a list of strings
                if hasattr(X_data_reversed, 'shape'):
                    print("X_data_reversed.dim: ", X_data_reversed.shape)

            for samples in range  (num_samples):
                print ("sample ", samples+1, "out of ", num_samples)

                fig=plt.figure()
                plt.plot (
                    result[samples,0,:].cpu().detach().numpy(),
                    label= f'Predicted'
                )
                plt.plot (
                    GT[samples,0,:],
                    label= f'GT {0}'
                )
                plt.legend()
                outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                if IF_showfig==1:
                    plt.show()
                else:
                    plt.savefig(outname, dpi=200)
                plt.close ()
                

                print (f"For input in dataloader: {X_train_batch[samples,:].cpu().detach().numpy()} or \n recovered input {X_data_reversed[samples]}")
                print (f"predicted sequence: {y_data_reversed[samples]}")
                print (f"Ground truth:       {GT_y_data_reversed[samples]}")

                if foldproteins:
                    # check whether the predicted sequence is valid
                    if len(y_data_reversed[samples])>0:
                        # # --
                        # xbc=X_train_batch[samples,:].cpu().detach().numpy()
                        # # out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                        # out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})
                        # ++
                        xbc=X_data_reversed[samples]
                        # fix: when a tokenizer was used, xbc is already a string;
                        # np.array2string applies only to the numeric case
                        if isinstance(xbc, str):
                            out_nam_content=xbc
                        else:
                            out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})
                        #
                        tempname='temp'
                        pdb_file,fasta_file=foldandsavePDB_pdb_fasta (
                            sequence=y_data_reversed[samples], 
                            filename_out=tempname, 
                            num_cycle=16, flag=flag,
                            # +++++++++++++++++++
                            prefix=prefix
                        )

                        # #out_nam=f'{prefix}{out_nam}.pdb'
                        # out_nam=f'{prefix}{X_data_reversed[samples]}.pdb'
                        # ------------------------------------------------------
                        # sometime, this name below can get too long to fit
                        # out_nam=f'{sample_dir}{X_data_reversed[samples]}.pdb'
                        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++
                        # add a way to save the sampling name and results
                        # ref: outname = sample_dir+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg"
                        out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb'
                        out_nam_seq=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.fasta'
                        out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt'

                        if CKeys['Debug_TrainerPack']==1:
                            print("pdb_file: ", pdb_file)
                            print("out_nam: ", out_nam)

                        print (f'Original PDB: {pdb_file} OUT: {out_nam}')
                        shutil.copy (pdb_file, out_nam) #source, dest
                        shutil.copy (fasta_file, out_nam_seq)
                        # +
                        with open(out_nam_inX, "w") as inX_file:
                            # inX_file.write(f'{X_data_reversed[samples]}\n')
                            inX_file.write(out_nam_content)
                        # clean the slate to avoid mistakenly reusing the previous fasta file
                        os.remove (pdb_file)
                        os.remove (fasta_file)


                        pdb_file=out_nam
                        print (f"Properly named PDB file produced: {pdb_file}")
                        print (f"input X for sampling stored: {pdb_file}")

                        if IF_showfig==1:
                            view=show_pdb(
                                pdb_file=pdb_file, 
                                flag=flag, 
                                show_sidechains=show_sidechains,  
                                show_mainchains=show_mainchains, 
                                color=color
                            )
                            view.show()
                            
                        if calc_error:
                            print('On-fly check...')
                            if CKeys['Problem_ID']==7:
                                # only works for ModelA:SecStr
                                get_Model_A_error (pdb_file, X_data_reversed[samples], plotit=True)
                            else:
                                print("Error calculation on the predicted results is not applicable...")
                    
                            
                    else:
                        print("The predicted sequence is EMPTY...")