import warnings

warnings.filterwarnings('ignore', category=DeprecationWarning)

import os

os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'

from pathlib import Path
from collections import defaultdict

import hydra
import numpy as np
import torch
import wandb
from dm_env import specs

import tools.utils as utils
from tools.logger import Logger
from tools.replay import ReplayBuffer, make_replay_loader

torch.backends.cudnn.benchmark = True

def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg):
    cfg.obs_type = obs_type
    cfg.obs_shape = obs_spec.shape
    cfg.action_shape = action_spec.shape
    cfg.num_expl_steps = num_expl_steps
    return hydra.utils.instantiate(cfg)


def make_dreamer_agent(obs_space, action_spec, cur_config, cfg):
    from copy import deepcopy
    cur_config = deepcopy(cur_config)
    if hasattr(cur_config, 'agent'):
        del cur_config.agent
    return hydra.utils.instantiate(cfg, cfg=cur_config, obs_space=obs_space, act_spec=action_spec)

class Workspace:
    def __init__(self, cfg, savedir=None, workdir=None,):
        self.workdir = Path.cwd() if workdir is None else workdir
        print(f'workspace: {self.workdir}')

        self.cfg = cfg
        
        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        # create logger
        self.logger = Logger(self.workdir,
                             use_tb=cfg.use_tb,
                             use_wandb=cfg.use_wandb)
        # create envs
        self.task = task = cfg.task
        img_size = cfg.img_size

        import envs.main as envs
        self.train_env = envs.make(task, cfg.obs_type, cfg.action_repeat, cfg.seed, img_size=img_size,  viclip_encode=cfg.viclip_encode, clip_hd_rendering=cfg.clip_hd_rendering) 


        # # create agent 
        sample_agent = make_dreamer_agent(self.train_env.obs_space, self.train_env.act_space['action'], cfg, cfg.agent)

        # create replay buffer
        data_specs = (self.train_env.obs_space,
                      self.train_env.act_space,
                      specs.Array((1,), np.float32, 'reward'),
                      specs.Array((1,), np.float32, 'discount'))

        if cfg.train_from_data:
            # Loading replay buffer
            if cfg.replay_from_wandb_project is not None:
                api = wandb.Api()
                project_name = cfg.replay_from_wandb_project
                params2search = {
                    "task" : cfg.task if cfg.task_snapshot is None else cfg.task_snapshot,
                    "seed" : cfg.seed if cfg.seed_snapshot is None else cfg.seed_snapshot,
                }
                runs = api.runs(f"PUT_YOUR_USER_HERE/{project_name}")
                found = False
                for run in runs:
                    if np.all([ v == run.config.get(k, None) for k,v in params2search.items()]):
                        found = True
                        found_path = Path(run.config['workdir'].replace('/code', ''))
                        break
                if not found:
                    raise Exception("Replay from wandb buffer not found")

                replay_dir = found_path / 'code' / 'buffer'
            else:
                replay_dir = Path(cfg.replay_load_dir)

            # create data storage
            self.replay_storage = ReplayBuffer(data_specs, [],
                                                    replay_dir,
                                                    length=cfg.batch_length, **cfg.replay,
                                                    device=cfg.device, ignore_extra_keys=True, load_recursive=True)
            print('Loaded ', self.replay_storage._loaded_episodes, 'episodes from ', str(replay_dir))

            # create replay buffer
            self.replay_loader = make_replay_loader(self.replay_storage,
                                                    cfg.batch_size,)
            self._replay_iter = None

        # Loading snapshot
        if cfg.snapshot_from_wandb_project is not None:
            api = wandb.Api()
            project_name = cfg.snapshot_from_wandb_project
            params2search = {
                "task" : cfg.task if cfg.task_snapshot is None else cfg.task_snapshot,
                "agent_name" : cfg.agent.name if cfg.agent_name_snapshot is None else cfg.agent_name_snapshot,
                "seed" : cfg.seed if cfg.seed_snapshot is None else cfg.seed_snapshot,
            }
            if cfg.agent.clip_lafite_noise > 0.:
                params2search['clip_lafite_noise'] = cfg.agent.clip_lafite_noise
            if cfg.agent.clip_add_noise > 0.:
                params2search['clip_add_noise'] = cfg.agent.clip_add_noise
            if cfg.reset_connector:
                del params2search['clip_add_noise']
            runs = api.runs(f"PUT_YOUR_USER_HERE/{project_name}")
            found = False
            for run in runs:
                if np.all([ v == run.config.get(k, None) for k,v in params2search.items()]):
                    found = True
                    found_path = Path(run.config['workdir'].replace('/code', ''))
                    break
            if not found:
                raise Exception("Snapshot from wandb not found")

            if cfg.snapshot_step is None:
                snapshot_dir = found_path / 'code' / 'last_snapshot.pt'
            else:
                snapshot_dir = found_path / 'code' / f'snapshot_{cfg.snapshot_step}.pt'
        elif cfg.snapshot_load_dir is not None:
            snapshot_dir = Path(cfg.snapshot_load_dir)
        else:
            snapshot_dir = None

        if snapshot_dir is not None:        
            self.load_snapshot(snapshot_dir, resume=False)
            if self.cfg.reset_world_model:
                self.agent.wm = sample_agent.wm 
                # To reset optimization
                from agent import dreamer_utils as common
                self.agent.wm.model_opt = common.Optimizer('model', self.agent.wm.parameters(), **self.agent.wm.cfg.model_opt, use_amp=self.agent.wm._use_amp)
            if self.cfg.reset_connector:
                self.agent.wm.connector = sample_agent.wm.connector
                # To reset optimization
                from agent import dreamer_utils as common
                self.agent.wm.model_opt = common.Optimizer('model', self.agent.wm.parameters(), **self.agent.wm.cfg.model_opt, use_amp=self.agent.wm._use_amp)

            # overwriting cfg
            self.agent.cfg = sample_agent.cfg
            self.agent.wm.cfg = sample_agent.wm.cfg 
            
            if self.cfg.reset_imag_behavior:
                self.agent.instantiate_imag_behavior()
        else:
            self.agent = sample_agent

        self.eval_env = envs.make(self.task, self.cfg.obs_type, self.cfg.action_repeat, self.cfg.seed, img_size=64, )
        if hasattr(self.eval_env, 'eval_mode'):
            self.eval_env.eval_mode()
        eval_specs = (self.eval_env.obs_space,
                        self.eval_env.act_space,
                        specs.Array((1,), np.float32, 'reward'),
                        specs.Array((1,), np.float32, 'discount'))
        self.eval_storage = ReplayBuffer(eval_specs, {},
                                                self.workdir / 'eval_buffer',
                                                length=cfg.batch_length, **cfg.replay,
                                                device=cfg.device, ignore_extra_keys=True,)
        self.eval_storage._minlen = 1

        self.timer = utils.Timer()
        self._global_step = 0
        self._global_episode = 0

    @property
    def global_step(self):
        return self._global_step

    @property
    def global_episode(self):
        return self._global_episode

    @property
    def global_frame(self):
        return self.global_step * self.cfg.action_repeat

    @property
    def replay_iter(self):
        if self._replay_iter is None:
            self._replay_iter = iter(self.replay_loader)
        return self._replay_iter

    def eval(self):
        import envs.main as envs
        eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
        episode_reward = []
        while eval_until_episode(len(episode_reward)):
            if len(episode_reward) > 0 and self.global_step == 0:
                return
            episode_reward.append(0)
            step, episode = 0, defaultdict(list)
            meta = self.agent.init_meta()
            time_step, dreamer_obs = self.eval_env.reset()
            data = dreamer_obs
            if 'clip_video' in data:
                del data['clip_video']
            self.eval_storage.add(data, meta) 
            agent_state = None
            while not time_step.last():
                with torch.no_grad(), utils.eval_mode(self.agent):
                    action, agent_state = self.agent.act(dreamer_obs, 
                                                meta,
                                                self.global_step,
                                                eval_mode=True,
                                                state=agent_state)
                time_step, dreamer_obs = self.eval_env.step(action)
                for k in dreamer_obs:
                    episode[k].append(dreamer_obs[k])
                episode_reward[-1] += time_step.reward
                if time_step.last():
                    if episode_reward[-1] == np.max(episode_reward):
                        best_episode = {**episode}
                    if episode_reward[-1] == np.min(episode_reward):
                        worst_episode = {**episode}
                data = dreamer_obs
                if 'clip_video' in data:
                    del data['clip_video']
                self.eval_storage.add(data, meta) 
                step += 1

        if self.global_step > 0 and self.global_frame % self.cfg.log_episodes_every_frames == 0:
            # B, T, C, H, W = video.shape
            videos = {'best_episode' : np.stack(best_episode['observation'], axis=0),
                    'worst_episode' : np.stack(worst_episode['observation'], axis=0),}
            self.logger.log_visual(videos, self.global_frame)

        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
            log('episode_reward', np.mean(episode_reward))
            log('episode_length', step * self.cfg.action_repeat)
            log('episode', self.global_episode)
            log('step', self.global_step)

    def eval_imag_behavior(self,):
        self.agent._backup_acting_behavior = self.agent._acting_behavior
        self.agent._acting_behavior = self.agent._imag_behavior
        self.eval()
        self.agent._acting_behavior = self.agent._backup_acting_behavior

    def train(self):
        # predicates
        train_until_step = utils.Until(self.cfg.num_train_frames, 1)
        eval_every_step = utils.Every(self.cfg.eval_every_frames, 1)
        should_log_scalars = utils.Every(self.cfg.log_every_frames, 1)
        should_save_model = utils.Every(self.cfg.save_every_frames, 1)
        should_log_visual = utils.Every(self.cfg.visual_every_frames, 1)
        metrics = None
        while train_until_step(self.global_step):
            # try to evaluate
            if eval_every_step(self.global_step):
                if self.cfg.eval_modality == 'task':
                    self.eval()
                if self.cfg.eval_modality == 'task_imag':
                    self.eval_imag_behavior()
                if self.cfg.eval_modality == 'from_text':
                    self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame)
                    self.eval_from_text()

            if self.cfg.train_from_data:
                # Sampling data
                batch_data = next(self.replay_iter)
                if self.cfg.train_world_model:
                    state, outputs, metrics = self.agent.update_wm(batch_data, self.global_step)
                else:
                    with torch.no_grad():
                        outputs, metrics = self.agent.wm.observe_data(batch_data,)
                if self.cfg.train_connector:
                    _, metrics = self.agent.wm.update_additional_detached_modules(batch_data, outputs, metrics)
            else:
                imag_warmup_steps = self.cfg.imag_warmup_steps
                metrics, batch_data = {}, None
                with torch.no_grad():
                    # fake actions 
                    mix = self.cfg.mix_random_actions
                    random = False
                    # num warmup steps

                    if mix:
                        init = self.agent.wm.rssm.initial(self.cfg.batch_size * (self.cfg.batch_length // 2))
                    else:                        
                        init = self.agent.wm.rssm.initial(self.cfg.batch_size * self.cfg.batch_length)


                    unif_dist = self.agent.wm.rssm.get_unif_dist(init)
                    if 'logit' in init:
                        init['logit'] = unif_dist.mean
                    else:
                        init['mean'] = unif_dist.mean 
                        init['std'] = unif_dist.std 
                    init['stoch'] = unif_dist.sample()
                    
                    if self.cfg.start_from_video in [True, 'mix']:
                        T = self.agent.wm.connector.n_frames * 2 # should this be an hyperparam?
                        B = init['deter'].shape[0] // T 
                        text_feat_dim = self.agent.wm.connector.viclip_emb_dim
                        video_embed = torch.randn((B, T, text_feat_dim), device=self.agent.device)
                        video_embed = torch.nn.functional.normalize(video_embed, dim=-1)
                        # Get initial state
                        video_init = self.agent.wm.connector.video_imagine(video_embed, dreamer_init=None, sample=True, reset_every_n_frames=False, denoise=True)
                        video_init = { k : v.reshape(B * T, *v.shape[2:]) for k, v in video_init.items()}
                        if self.cfg.start_from_video == 'mix':
                            probs = torch.rand((B * T, 1,1), device=init['stoch'].device) > 0.5  # should this be an hyperparam?
                            init['stoch'] = (probs * init['stoch']) + ( (~probs) * video_init['stoch'] )
                        else:
                            init['stoch'] = video_init['stoch']
                    
                    if random:
                        fake_action = torch.rand(self.cfg.batch_size * self.cfg.batch_length, imag_warmup_steps, self.agent.act_dim, device=self.agent.device) * 2 - 1
                        post = self.agent.wm.rssm.imagine(fake_action, init, sample=True)
                        post = { k : v[:, -1].reshape([self.cfg.batch_size, self.cfg.batch_length, ] + list(v.shape[2:])) for k,v in post.items() }
                    elif mix:
                        fake_action = torch.rand(self.cfg.batch_size * self.cfg.batch_length // 2, imag_warmup_steps, self.agent.act_dim, device=self.agent.device) * 2 - 1
                        post1 = self.agent.wm.rssm.imagine(fake_action, init, sample=True)
                        post1 = { k : v[:, -1].reshape([self.cfg.batch_size, self.cfg.batch_length // 2, ] + list(v.shape[2:])) for k,v in post1.items() }

                        init2 = { k : v.reshape([self.cfg.batch_size, self.cfg.batch_length // 2, ] + list(v.shape[1:])) for k,v in init.items() }
                        post2 = self.agent.wm.imagine(self.agent._imag_behavior.actor, init2, None, imag_warmup_steps) 
                        post2 = { k : v[-1, :].reshape([self.cfg.batch_size, self.cfg.batch_length // 2, ] + list(v.shape[2:])) for k,v in post2.items() }
                        post = { k: torch.cat([post1[k], post2[k]], dim=1) for k in post1 }
                    else:
                        init = { k : v.reshape([self.cfg.batch_size, self.cfg.batch_length, ] + list(v.shape[1:])) for k,v in init.items() }
                        post = self.agent.wm.imagine(self.agent._imag_behavior.actor, init, None, imag_warmup_steps) 
                        post = { k : v[-1, :].reshape([self.cfg.batch_size, self.cfg.batch_length, ] + list(v.shape[2:])) for k,v in post.items() }

                is_terminal = torch.zeros(self.cfg.batch_size, self.cfg.batch_length, device=self.agent.device)
                outputs = dict(post=post, is_terminal=is_terminal)
            if getattr(self.cfg.agent, 'imag_reward_fn', None) is not None:
                metrics.update(self.agent.update_imag_behavior(state=None, outputs=outputs, metrics=metrics, seq_data=batch_data,)[1])

            if self.global_step > 0:
                if should_log_scalars(self.global_step):
                    if hasattr(self, 'replay_storage'):
                        metrics.update(self.replay_storage.stats)
                    self.logger.log_metrics(metrics, self.global_frame, ty='train')
                if should_log_visual(self.global_step) and self.cfg.train_from_data and hasattr(self.agent, 'report'):
                    with torch.no_grad(), utils.eval_mode(self.agent):
                        videos = self.agent.report(next(self.replay_iter))
                        self.logger.log_visual(videos, self.global_frame)
                if should_log_scalars(self.global_step):
                    elapsed_time, total_time = self.timer.reset()
                    with self.logger.log_and_dump_ctx(self.global_frame, ty='train') as log:
                        log('fps', self.cfg.log_every_frames / elapsed_time)
                        log('step', self.global_step)
                        if 'model_loss' in metrics: 
                            log('episode_reward', metrics['model_loss'].item())
                    
                # save last model
                if should_save_model(self.global_step):
                    self.save_last_model()

            self._global_step += 1
            # == 1000 is to make sure everything is going well since the start
            if (self.global_frame == 1000) or (self.global_frame % self.cfg.snapshot_every_frames == 0):
                self.save_snapshot()

    @utils.retry
    def save_snapshot(self):
        snapshot = self.root_dir / f'snapshot_{self.global_frame}.pt'
        keys_to_save = ['agent', '_global_step', '_global_episode']
        payload = {k: self.__dict__[k] for k in keys_to_save}
        with snapshot.open('wb') as f:
            torch.save(payload, f)

    def setup_wandb(self):
        cfg = self.cfg
        exp_name = '_'.join([
            cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type,
            str(cfg.seed)
        ])
        wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name)
        flat_cfg = utils.flatten_dict(cfg)
        wandb.config.update(flat_cfg)
        self.wandb_run_id = wandb.run.id

    @utils.retry
    def save_last_model(self):
        snapshot = self.root_dir / 'last_snapshot.pt'
        if snapshot.is_file():
            temp = Path(str(snapshot).replace("last_snapshot.pt", "second_last_snapshot.pt"))
            os.replace(snapshot, temp)
        keys_to_save = ['agent', '_global_step', '_global_episode']
        if self.cfg.use_wandb: 
            keys_to_save.append('wandb_run_id')
        payload = {k: self.__dict__[k] for k in keys_to_save}
        with snapshot.open('wb') as f:
            torch.save(payload, f)

    @utils.retry
    def load_snapshot(self, snapshot_dir, resume=True):
        print('Loading snapshot from: ', str(snapshot_dir))
        try:
            snapshot = snapshot_dir / 'last_snapshot.pt' if resume else snapshot_dir
            with snapshot.open('rb') as f:
                payload = torch.load(f)
        except:
            snapshot = Path(str(snapshot_dir).replace('last_snapshot', 'second_last_snapshot'))
            with snapshot.open('rb') as f:
                payload = torch.load(f)
        if type(payload) != dict:
            self.agent = payload
            self.agent.requires_grad_(requires_grad=False)
            return
        for k,v in payload.items():
            setattr(self, k, v)
            if k == 'wandb_run_id' and resume:
                assert wandb.run is None
                cfg = self.cfg
                exp_name = '_'.join([
                    cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type,
                    str(cfg.seed)
                ])
                wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name, id=v, resume="must")

    def get_snapshot_dir(self):
        snap_dir = self.cfg.snapshot_dir
        snapshot_dir = self.workdir / Path(snap_dir)
        snapshot_dir.mkdir(exist_ok=True, parents=True)
        return snapshot_dir 

def start_training(cfg, savedir, workdir):
    from train import Workspace as W
    root_dir = Path.cwd()
    cfg.workdir = str(root_dir)
    workspace = W(cfg, savedir, workdir)
    workspace.root_dir = root_dir
    snapshot = workspace.root_dir / 'last_snapshot.pt'
    if snapshot.exists():
        print(f'resuming: {snapshot}')
        workspace.load_snapshot(workspace.root_dir)
    if cfg.use_wandb and wandb.run is None:
        # otherwise it was resumed
        workspace.setup_wandb()
    workspace.train()

@hydra.main(config_path='.', config_name='train')
def main(cfg):
    start_training(cfg, None, None)

if __name__ == '__main__':
    main()