import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import *


def init_model(cfgs):
    """Build the model from its config file and load checkpoint weights.

    Reads ``cfgs.model_cfg_path`` (config file), ``cfgs.load_ckpt_path``
    (checkpoint), ``cfgs.type`` and, outside of training, ``cfgs.gpu``.

    When ``cfgs.type == "train"`` the model is only switched to train mode
    and is not moved to any device here. For every other type the model is
    moved to the CUDA device with index ``cfgs.gpu``, switched to eval mode
    and frozen.

    Returns the instantiated model.
    """
    config = OmegaConf.load(cfgs.model_cfg_path)
    checkpoint_path = cfgs.load_ckpt_path

    net = instantiate_from_config(config.model)
    net.init_from_ckpt(checkpoint_path)

    # Training: leave device placement to the caller.
    if cfgs.type == "train":
        net.train()
        return net

    # Any non-training type: place on the requested GPU and lock the weights.
    target_device = torch.device("cuda", index=cfgs.gpu)
    net.to(target_device)
    net.eval()
    net.freeze()
    return net

def init_sampling(cfgs):
    """Create the ``EulerEDMSampler`` used for sampling.

    Reads ``cfgs.scale`` (only its first entry is used, as the guider
    scale), ``cfgs.steps`` (number of sampling steps) and ``cfgs.gpu``
    (CUDA device index). All remaining sampler settings are fixed here.

    Returns the constructed sampler.
    """
    discretizer_spec = dict(
        target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    )
    guidance_spec = dict(
        target="sgm.modules.diffusionmodules.guiders.VanillaCFG",
        params=dict(scale=cfgs.scale[0]),
    )

    sampler_kwargs = dict(
        num_steps=cfgs.steps,
        discretization_config=discretizer_spec,
        guider_config=guidance_spec,
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )
    return EulerEDMSampler(**sampler_kwargs)

def deep_copy(batch):
    """Return a new dict holding independent copies of ``batch``'s values.

    Despite the name, the copy goes one level deep only:

    * ``torch.Tensor`` values are cloned with ``torch.clone``.
    * ``list`` values are shallow-copied; their elements are shared with
      the original list.
    * Everything else, tuples included, is stored by reference.

    Args:
        batch: Mapping from keys to arbitrary values.

    Returns:
        A new ``dict`` with the same keys as ``batch``.
    """
    c_batch = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            c_batch[key] = torch.clone(value)
        elif isinstance(value, list):
            c_batch[key] = value.copy()
        else:
            # Tuples land here on purpose: they have no ``.copy()`` method
            # (calling it raised AttributeError), and being immutable they
            # can be shared safely at the container level.
            c_batch[key] = value

    return c_batch

def prepare_batch(cfgs, batch):
    """Move every tensor in ``batch`` to the CUDA device ``cfgs.gpu``.

    ``batch`` is updated in place; non-tensor values are left untouched.

    Returns a pair ``(batch, batch_uc)``. Both entries are the *same* dict
    object as the ``batch`` argument, so a change made through one is
    visible through the other.
    NOTE(review): if callers expect an independent unconditional batch,
    they must copy it themselves — confirm against the call sites.
    """
    for name in batch:
        value = batch[name]
        if not isinstance(value, torch.Tensor):
            continue
        # Existing keys are only re-assigned, so iterating while writing is safe.
        batch[name] = value.to(torch.device("cuda", index=cfgs.gpu))

    return batch, batch