from diffusers import (
    DPMSolverMultistepScheduler,
    DDPMScheduler,
    DDIMScheduler,
    PNDMScheduler,
    DEISMultistepScheduler,
)
import torch
import yaml
import math
import tqdm
import time


class DiffusePipeline(object):
    """Sampling pipeline that pairs a denoising model with a diffusers scheduler.

    The scheduler class and its constructor arguments are read from a YAML
    file keyed by ``diffuser_name``; the training-time settings stored on
    ``opt`` (``diffusion_steps``, ``beta_schedule``, ``prediction_type``)
    override the corresponding scheduler arguments.
    """

    # Cumulative number of sampling steps after which a batch counts towards
    # the reported average time (earlier batches are treated as warm-up).
    WARMUP_STEPS = 50

    def __init__(
        self,
        opt,
        model,
        diffuser_name,
        num_inference_steps,
        device,
        torch_dtype=torch.float16,
        params_path="config/diffuser_params.yaml",
    ):
        """Move the model to ``device`` and build the scheduler.

        Args:
            opt: Options object; ``diffusion_steps``, ``beta_schedule`` and
                ``prediction_type`` are read from it.
            model: Denoising model. It is converted with ``.half()`` when
                ``torch_dtype`` is ``torch.float16`` and moved to ``device``.
            diffuser_name: Key of the scheduler entry in the YAML file.
            num_inference_steps: Number of sampling steps per generation.
            device: Device used for the model and the sampled tensors.
            torch_dtype: dtype of the initial noise sample.
            params_path: Path of the YAML file describing the schedulers.

        Raises:
            ValueError: If ``diffuser_name`` is not in the YAML file or the
                configured scheduler class name is not a module-level name.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.diffuser_name = diffuser_name
        self.num_inference_steps = num_inference_steps
        if self.torch_dtype == torch.float16:
            model = model.half()
        self.model = model.to(device)
        self.opt = opt

        # Load scheduler parameters from the YAML file (an empty file yields
        # None, which is treated as "no schedulers configured").
        with open(params_path, "r", encoding="utf-8") as yaml_file:
            diffuser_params = yaml.safe_load(yaml_file) or {}

        # Select the diffusion parameters based on diffuser_name.
        if diffuser_name not in diffuser_params:
            raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")
        params = diffuser_params[diffuser_name]
        scheduler_class_name = params["scheduler_class"]
        # Copy so the parsed YAML structure is not mutated below.
        additional_params = dict(params.get("additional_params") or {})

        # Align the scheduler with the training parameters.
        additional_params["num_train_timesteps"] = opt.diffusion_steps
        additional_params["beta_schedule"] = opt.beta_schedule
        additional_params["prediction_type"] = opt.prediction_type

        # NOTE(review): the class name comes from the YAML file and is
        # resolved against every module-level name, not only the imported
        # schedulers -- keep the config file trusted.
        try:
            scheduler_class = globals()[scheduler_class_name]
        except KeyError:
            raise ValueError(
                f"Class '{scheduler_class_name}' not found."
            ) from None

        self.scheduler = scheduler_class(**additional_params)

    @staticmethod
    def _synchronize():
        """Wait for pending CUDA work, doing nothing when CUDA is unavailable."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def generate_batch(self, caption, m_lens):
        """Sample one batch, starting from Gaussian noise.

        Args:
            caption: Sequence of captions; its length is the batch size.
            m_lens: Per-sample lengths (tensor or any sequence of ints); the
                maximum is used as the second dimension of the sample.

        Returns:
            Tensor of shape ``(len(caption), max(m_lens), model.input_feats)``
            holding the final denoised sample.
        """
        batch = len(caption)
        # Accept lists as well as tensors and use a plain int for the shape.
        max_len = int(torch.as_tensor(m_lens).max())
        shape = (batch, max_len, self.model.input_feats)

        # Classifier-free guidance is used when the model exposes a positive
        # cond_mask_prob.
        use_cfg = getattr(self.model, "cond_mask_prob", 0) > 0

        # Sampling never needs gradients, so run everything under no_grad.
        with torch.no_grad():
            # random sampling noise x_T
            sample = torch.randn(
                shape, device=self.device, dtype=self.torch_dtype
            )

            # set timesteps
            self.scheduler.set_timesteps(self.num_inference_steps, self.device)

            # cache text_embedded
            enc_text = self.model.encode_text(caption, self.device)

            for step in self.scheduler.timesteps:
                # One timestep value per batch element, as a long tensor.
                t = torch.full(
                    (batch,), int(step), device=self.device, dtype=torch.long
                )

                # 1. model predict
                if use_cfg:
                    predict = self.model.forward_with_cfg(
                        sample, t, enc_text=enc_text
                    )
                else:
                    predict = self.model(sample, t, enc_text=enc_text)

                # 2. compute less noisy motion and set x_t -> x_t-1
                sample = self.scheduler.step(predict, t[0], sample).prev_sample

        return sample

    def generate(self, caption, m_lens, batch_size=32):
        """Generate a motion for every caption, in batches, and time it.

        Args:
            caption: Sequence of captions.
            m_lens: Per-caption lengths, indexable and sliceable like
                ``caption``; each output is cropped to its length.
            batch_size: Maximum number of captions per batch.

        Returns:
            Tuple ``(all_output, t_eval)``: the list of cropped per-caption
            outputs and the average batch generation time in seconds. Only
            batches past the warm-up threshold are averaged; if no batch
            reaches it, the average over all batches is returned. With no
            captions the result is ``([], 0.0)``.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        N = len(caption)
        infer_mode = ""
        if getattr(self.model, "cond_mask_prob", 0) > 0:
            infer_mode = "classifier-free-guidance"
        print(
            f"\nUsing {self.diffuser_name} diffusion scheduler to {infer_mode} generate {N} motions, sampling {self.num_inference_steps} steps."
        )
        self.model.eval()

        all_output = []
        if N == 0:
            return all_output, 0.0

        num_batches = math.ceil(N / batch_size)
        total_time = 0.0
        timed_time = 0.0
        timed_batches = 0
        for batch_idx in tqdm.tqdm(range(num_batches)):
            # Slicing clamps at the end, so the last batch may be shorter.
            start = batch_idx * batch_size
            batch_caption = caption[start : start + batch_size]
            batch_m_lens = m_lens[start : start + batch_size]

            # Synchronize around the call so queued GPU work is included.
            self._synchronize()
            start_time = time.perf_counter()
            output = self.generate_batch(batch_caption, batch_m_lens)
            self._synchronize()
            elapsed = time.perf_counter() - start_time

            total_time += elapsed
            # The average inference time is calculated after GPU warm-up in
            # the first 50 steps.
            if (batch_idx + 1) * self.num_inference_steps >= self.WARMUP_STEPS:
                timed_time += elapsed
                timed_batches += 1

            # Crop motion with gt/predicted motion length
            for i in range(output.shape[0]):
                all_output.append(output[i, : int(batch_m_lens[i])])

        # Calculate the average inference time over the batches actually
        # timed, falling back to all batches when warm-up was never passed.
        if timed_batches:
            t_eval = timed_time / timed_batches
        else:
            t_eval = total_time / num_batches
        print(
            "The average generation time of a batch motion (bs=%d) is %f seconds"
            % (batch_size, t_eval)
        )
        return all_output, t_eval