import numpy as np
import cv2
import os
import torch
import tqdm
import time
import imageio
from sim.simulator import GenieSimulator, RobomimicSimulator
from diffusion_policy.util.pytorch_util import dict_apply
from sim.policy import DiffusionPolicy

DP_RES = 84  # side length (pixels) of the square image fed to the diffusion policy
MAX_STEPS = 100  # maximum number of simulator steps per evaluation trial
NUM_EVAL_TRIALS = 50  # number of trials run for each policy checkpoint
VIDEO_ROOT = 'data/policy_eval_videos'  # root directory for the saved rollout videos


def left_half(image):
    """Return the left half (along the width axis, axis 1) of an image array.

    The frame returned by ``GenieSimulator.reset()`` is saved into the same
    video as the per-step frames, which are ``[pred | gt]`` concatenated along
    the width axis. The reset frame therefore presumably has the same
    side-by-side layout, and only its left half is meant for the policy.
    NOTE(review): confirm this layout against ``GenieSimulator.reset``.
    """
    return image[:, :image.shape[1] // 2]


def to_policy_obs(image):
    """Build a single-step policy observation dict from an H x W x C frame.

    The frame is resized to ``DP_RES x DP_RES`` and transposed to channels-first.
    """
    return {
        'agentview_image': cv2.resize(image, (DP_RES, DP_RES)).transpose(2, 0, 1),
    }


def video_path(sr, trial):
    """Return the mp4 path for checkpoint success-rate ``sr`` and trial index ``trial``."""
    return f'{VIDEO_ROOT}/policy_{sr:.2f}/{trial:02d}.mp4'


def run_trial(genie_simulator, diffusion_policy, n_obs_steps):
    """Run one closed-loop episode of ``diffusion_policy`` inside ``genie_simulator``.

    Args:
        genie_simulator: the learned simulator; it is reset at the start.
        diffusion_policy: policy exposing ``generate_action``, ``device`` and ``dtype``.
        n_obs_steps: length of the observation history the policy consumes.

    Returns:
        ``(simulation_frames, elapsed)``. ``simulation_frames`` is the reset frame
        followed by one ``[pred | gt]`` frame per executed step. ``elapsed`` is the
        wall-clock seconds spent in the rollout, excluding the reset.
    """
    genie_image = genie_simulator.reset()
    simulation_frames = [genie_image]
    # The policy only ever sees the predicted frame, never the side-by-side view,
    # including on the very first step.
    policy_image = left_half(genie_image)

    # Fill the observation history with copies of the first observation.
    obs_dict_buf = dict_apply(
        to_policy_obs(policy_image),
        lambda x: x[np.newaxis].repeat(n_obs_steps, axis=0),
    )

    done = False
    start_time = time.perf_counter()
    with tqdm.tqdm(total=MAX_STEPS) as pbar:
        while not done and pbar.n < MAX_STEPS:
            # Shift the history one step back, then write the newest observation last.
            obs_dict_buf = dict_apply(
                obs_dict_buf,
                lambda x: np.roll(x, shift=-1, axis=0),
            )
            for key, value in to_policy_obs(policy_image).items():
                obs_dict_buf[key][-1] = value

            # Predict an action chunk from the observation history (batch of 1).
            # NOTE(review): the uint8 images are cast to the policy dtype without
            # rescaling; presumably DiffusionPolicy normalizes internally — confirm.
            traj = diffusion_policy.generate_action(dict_apply(
                obs_dict_buf,
                lambda x: torch.from_numpy(x).to(
                    device=diffusion_policy.device, dtype=diffusion_policy.dtype
                ).unsqueeze(0),
            ))['action'].squeeze(0).detach().cpu().numpy()

            # Execute the chunk open-loop, but stop as soon as the episode ends or
            # the step budget runs out. Otherwise a later step could overwrite
            # ``done`` with False, and the trial would overshoot MAX_STEPS.
            for action in traj:
                result = genie_simulator.step(action[np.newaxis])
                done = result['done']
                policy_image = result['pred_next_frame']
                simulation_frames.append(
                    np.concatenate([policy_image, result['gt_next_frame']], axis=1)
                )
                pbar.update(1)
                if done or pbar.n >= MAX_STEPS:
                    break
    elapsed = time.perf_counter() - start_time
    return simulation_frames, elapsed


def main():
    """Evaluate each diffusion-policy checkpoint inside the learned Genie simulator.

    For every checkpoint, the script runs ``NUM_EVAL_TRIALS`` rollouts, saves
    one video per rollout, and prints the timing summary. The simulator is
    closed even if evaluation fails.
    """
    robomimic_simulator = RobomimicSimulator(env_name='lift')
    genie_simulator = GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/mar_ckpt/robomimic_mixed",
        prompt_horizon=11,
        action_stride=1,
        domain='robomimic',
        physics_simulator=robomimic_simulator,
        physics_simulator_teacher_force=None,
    )
    try:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if genie_simulator.action_stride != 1:
            raise ValueError("currently only support action stride of 1")

        # load the policy
        success_rates = [1.00, 0.70, 0.52, 0.38]
        eval_time_taken = [0.0] * len(success_rates)
        for index, sr in enumerate(success_rates):
            diffusion_policy = DiffusionPolicy(f'data/dp_ckpt/dp_lift_sr{sr:.2f}.ckpt')
            n_obs_steps = diffusion_policy.n_obs_steps

            for trial in range(NUM_EVAL_TRIALS):
                simulation_frames, elapsed = run_trial(
                    genie_simulator, diffusion_policy, n_obs_steps
                )
                eval_time_taken[index] += elapsed

                # save the simulation frames
                path = video_path(sr, trial)
                os.makedirs(os.path.dirname(path), exist_ok=True)
                print(f"Saving {len(simulation_frames)} frames to {path}")
                imageio.mimsave(path, simulation_frames, fps=10)

            print(f"This checkpoint took {eval_time_taken[index]} seconds to evaluate")

        print("======= Evaluation Time Taken =======")
        for sr, t in zip(success_rates, eval_time_taken):
            print(f"SR={sr:.2f}: {t:.2f} seconds")
        # Mean per-checkpoint total divided by trials = mean time per single trial.
        print(f"Average time taken per eval: {np.mean(eval_time_taken) / NUM_EVAL_TRIALS:.2f} seconds")

        print("======= Simulation Done =======")
    finally:
        genie_simulator.close()


if __name__ == '__main__':
    main()