from typing import Any, Dict
from schema import Schema, Or
from data import Scenario, MergedDataset
from methods.base.alg import BaseAlg
from data import build_dataloader
from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel
from ...model.base import ElasticDNNUtil
import torch
import torch.optim
import tqdm
import torch.nn.functional as F
from torch import nn
from utils.dl.common.env import create_tbwriter
import os
import random
import numpy as np
from copy import deepcopy
from utils.dl.common.model import LayerActivation, get_module
from utils.common.log import logger


class ElasticDNN_MDPretrainingAlg(BaseAlg):
    """
    construct indexes between a filter/row of MD and all filters/rows of FM in the same layer
    too huge indexes (~1GB), train so slow, hard to optimize
    """
    def get_required_models_schema(self) -> Schema:
        return Schema({
            'fm': ElasticDNN_OfflineFMModel,
            'md': ElasticDNN_OfflineMDModel
        })
        
    def get_required_hyp_schema(self) -> Schema:
        return Schema({
            'launch_tbboard': bool,
            
            'samples_size': (int, int, int, int),
            'generate_md_width_ratio': int,
            
            'FBS_r': int,
            'FBS_ignore_layers': [str],
            
            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,            
            'optimizer': str,
            'md_optimizer_args': dict,
            'indexes_optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,
            'max_sparsity': float,
            'min_sparsity': float,
            'distill_loss_weight': float,
            'index_loss_weight': float,
            'val_num_sparsities': int,
            
            'bn_cal_num_iters': int,
            
            'index_guided_linear_comb_split_size': Or(int, None)
        })
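    # A hypothetical example of a hyps dict that satisfies the schema above; every value is purely
    # illustrative, not a tuned or recommended default:
    #
    # hyps = {
    #     'launch_tbboard': False,
    #     'samples_size': (1, 3, 224, 224),
    #     'generate_md_width_ratio': 4,
    #     'FBS_r': 16,
    #     'FBS_ignore_layers': [],
    #     'train_batch_size': 128,
    #     'val_batch_size': 512,
    #     'num_workers': 8,
    #     'optimizer': 'AdamW',
    #     'md_optimizer_args': {'lr': 1e-4, 'weight_decay': 0.01},
    #     'indexes_optimizer_args': {'lr': 3e-3, 'weight_decay': 0.},
    #     'scheduler': 'StepLR',
    #     'scheduler_args': {'step_size': 20000, 'gamma': 0.1},
    #     'num_iters': 80000,
    #     'val_freq': 400,
    #     'max_sparsity': 0.9,
    #     'min_sparsity': 0.0,
    #     'distill_loss_weight': 1.0,
    #     'index_loss_weight': 1e-4,
    #     'val_num_sparsities': 4,
    #     'bn_cal_num_iters': 800,
    #     'index_guided_linear_comb_split_size': 128
    # }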
        
    def upsample_2d_tensor(self, p: torch.Tensor, target_len: int):
        assert p.dim() == 2  # regard a 2d weight as (batch_size, 1d_vector_dim)
        # F.upsample is deprecated; F.interpolate is its direct replacement
        return F.interpolate(p.unsqueeze(1).unsqueeze(3),
                             size=(target_len, 1),
                             mode='bilinear', align_corners=False).squeeze(3).squeeze(1)
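    # For example (sizes purely illustrative): upsampling a (16, 768) MD weight with target_len=3072
    # yields a (16, 3072) tensor whose rows are bilinearly stretched to the FM's per-row width.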
        
    def two_params_diff_fast(self, trained_p: torch.Tensor, ref_p: torch.Tensor, 
                             index: torch.Tensor, 
                             split_size: int):
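        # Computes || sum_j index[i, j] * ref_p[j] - trained_p[i] ||_2^2 summed over the MD rows i,
        # i.e. how well an index-weighted linear combination of the FM's rows reconstructs each
        # (possibly upsampled) MD row of the same layer.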

        assert trained_p.dim() == ref_p.dim()
        assert index.size(0) == trained_p.size(0) and index.size(1) == ref_p.size(0)
        
        ref_p = ref_p.detach()
        if trained_p.dim() > 1:
            trained_p = trained_p.flatten(1)
            ref_p = ref_p.flatten(1)
            
            # the weight sizes of the master DNN and the foundation model may be totally different;
            # if the MD weight is narrower than the FM weight, upsample it to the FM row width first
            # (MD -> FM: upsample first; FM -> MD: downsample first)
            if trained_p.size(1) < ref_p.size(1):
                trained_p = self.upsample_2d_tensor(trained_p, ref_p.size(1))
            
            # (md_rows, fm_rows) -> (md_rows, fm_rows, 1) so the index broadcasts over the flattened weight dim
            index = index.unsqueeze(-1)
        
        if split_size is None:
            # old version: fastest, but huge memory consumption (materializes the full
            # (md_rows, fm_rows, weight_dim) tensor), so it is not recommended
            linear_combed_ref_p = (ref_p.unsqueeze(0) * index).sum(1)
        
        else:
            # new version: accumulate the index-weighted FM rows chunk by chunk to bound peak memory
            linear_combed_ref_p = 0
            
            # shrink the split size until it evenly divides the number of FM rows
            cur_split_size = split_size
            while index.size(1) % cur_split_size != 0:
                cur_split_size -= 1
            
            for i in range(0, index.size(1), cur_split_size):
                linear_combed_ref_p += ref_p.unsqueeze(0)[:, i: i + cur_split_size] * index[:, i: i + cur_split_size]
            linear_combed_ref_p = linear_combed_ref_p.sum(1)
            
        diff = (linear_combed_ref_p - trained_p).norm(2) ** 2
        return diff
        
    def get_index_loss(self, fm, md, indexes, match_fn, split_size):
        res = 0.

        for name, p in md.named_parameters():
            if p.dim() == 0:
                continue
            
            raw_p = match_fn(name, fm)
            if raw_p is None:
                continue

            index = indexes[name]
            
            # print(name)
            res += self.two_params_diff_fast(p, raw_p, index, split_size)
        return res
    
    def bn_cal(self, model: nn.Module, train_loader, num_iters, device):
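        # Recalibrate BatchNorm running statistics for the current sparsity and return a mapping from
        # module name to the recalibrated BatchNorm2d module (empty if the model has no BatchNorm2d layers).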
        # BN recalibration is only needed if the model actually contains BatchNorm2d layers
        has_bn = any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
        
        if not has_bn:
            return {}
        
        def bn_calibration_init(m):
            """ calculating post-statistics of batch normalization """
            if getattr(m, 'track_running_stats', False):
                # reset all values for post-statistics
                m.reset_running_stats()
                # set bn in training mode to update post-statistics
                m.training = True
                
        with torch.no_grad():
            model.eval()
            model.apply(bn_calibration_init)
            for _ in range(num_iters):
                x, _ = next(train_loader)
                model(x.to(device))
            model.eval()
            
        # collect the recalibrated BN modules; their running stats now reflect the current sparsity
        bn_stats = {}
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                bn_stats[n] = m
        return bn_stats
        
    def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]:
        super().run(scenario, hyps)
        
        # sanity check
        # a=  torch.tensor([[1, 2, 3], [1, 2, 4]])
        # index = torch.tensor([[1, 2, 3],
        # [1, 2, 4]])
        # b = torch.tensor([[1, 2, 3], [1, 2, 4], [2, 3, 4]])
        # print(self.two_params_diff_fast(a, b, index, hyps['index_guided_linear_comb_split_size']))
        
        assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel) # for auto completion
        assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel) # for auto completion
        
        # 1. add FBS
        device = self.models['md'].device
        
        logger.info('initializing master DNN by reducing the width of the adapted foundation model (already tuned by LoRA)...')
        
        before_fm_model = deepcopy(self.models['fm'].models_dict['main'])
        lora_util = self.models['fm'].get_lora_util()
        lora_absorbed_fm_model = lora_util.absorb_lora_and_recover_net_structure(self.models['fm'].models_dict['main'], 
                                                                                 torch.rand(hyps['samples_size']).to(device))
        self.models['fm'].models_dict['main'] = lora_absorbed_fm_model
        master_dnn = self.models['fm'].generate_md_by_reducing_width(hyps['generate_md_width_ratio'], 
                                                                     torch.rand(hyps['samples_size']).to(device))
        self.models['fm'].models_dict['main'] = before_fm_model
        
        elastic_dnn_util = self.models['fm'].get_elastic_dnn_util()
        master_dnn = elastic_dnn_util.convert_raw_dnn_to_master_dnn_with_perf_test(master_dnn,
                                                                                   hyps['FBS_r'], hyps['FBS_ignore_layers'])
        self.models['md'].models_dict['main'] = master_dnn
        self.models['md'].to(device)
        
        # 2. train (knowledge distillation, index relationship)
        offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])
        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
                                        True, None))
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
                                      False, False)
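        # NOTE: next(train_loader) is used below for both training and BN calibration; this assumes
        # build_dataloader (with the flags used above) yields enough batches for the whole run.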
        
        # 2.1 train only FBS (skipped because current md cannot do proper inference)
            
        # 2.2 train whole master DNN (knowledge distillation, index relationship)
        for p in master_dnn.parameters():
            p.requires_grad = True
        self.models['md'].to_train_mode()
        
        # build one trainable index per MD parameter that has a matched FM parameter:
        # a zero-initialized (md_rows, fm_rows) matrix relating each MD row to all FM rows of the same layer
        indexes = {}
        for name, p in self.models['md'].models_dict['main'].named_parameters():
            if p.dim() > 0:
                matched_p_in_fm = self.models['md'].get_matched_param_of_fm(name, self.models['fm'].models_dict['main'])
                if matched_p_in_fm is None:
                    continue
                indexes[name] = torch.zeros((p.size(0), matched_p_in_fm.size(0))).to(device)
                indexes[name].requires_grad = True
        
        # estimate the on-disk size of the indexes by serializing them once
        tmp_indexes_file_path = os.path.join(self.res_save_dir, 'tmp-indexes.pt')
        torch.save(indexes, tmp_indexes_file_path)
        logger.info(f'generated indexes ({(os.path.getsize(tmp_indexes_file_path) / 1024**2):.3f}MB)')
        os.remove(tmp_indexes_file_path)
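        # Rough footprint intuition: each index is a (md_rows, fm_rows) float32 matrix, so a hypothetical
        # layer with 768 MD rows matched to 3072 FM rows alone costs 768 * 3072 * 4 B ≈ 9 MB, which is how
        # the full set of indexes reaches the ~1GB scale mentioned in the class docstring.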
        
        # look up the optimizer class by name (e.g. 'SGD', 'AdamW') and give the MD weights and the
        # indexes their own param groups and hyperparameters
        optimizer = torch.optim.__dict__[hyps['optimizer']]([
            {'params': self.models['md'].models_dict['main'].parameters(), **hyps['md_optimizer_args']},
            {'params': list(indexes.values()), **hyps['indexes_optimizer_args']}
        ])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
        tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)
        best_avg_val_acc = 0.
        
        md_output_hook = None
        
        for iter_index in pbar:
            self.models['md'].to_train_mode()
            self.models['fm'].to_eval_mode()
            
            # sample a random sparsity in [min_sparsity, max_sparsity] for this iteration
            rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity']
            elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity)
            
            x, y = next(train_loader)
            x, y = x.to(device), y.to(device)
            
            with torch.no_grad():
                fm_output = self.models['fm'].infer(x)

            # hook the MD model so its forward output can be reused for the distillation loss below
            if md_output_hook is None:
                md_output_hook = LayerActivation(self.models['md'].models_dict['main'], False, device)
            task_loss = self.models['md'].forward_to_get_task_loss(x, y)
            md_output = md_output_hook.output
            distill_loss = hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output)
            index_loss = hyps['index_loss_weight'] * self.get_index_loss(self.models['fm'].models_dict['main'], 
                                                                         self.models['md'].models_dict['main'], 
                                                                         indexes,
                                                                         self.models['md'].get_matched_param_of_fm,
                                                                         hyps['index_guided_linear_comb_split_size'])
            total_loss = task_loss + distill_loss + index_loss
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            scheduler.step()
            
            if (iter_index + 1) % hyps['val_freq'] == 0:
                
                elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main'])
                md_output_hook.remove()
                md_output_hook = None
                
                # evaluate a deepcopy at several sparsities so the training copy's weights and BN stats stay untouched
                cur_md = self.models['md'].models_dict['main']
                md_for_test = deepcopy(self.models['md'].models_dict['main'])
                val_accs = {}
                avg_val_acc = 0.
                bn_stats = {}
                
                for val_sparsity in np.linspace(hyps['min_sparsity'], hyps['max_sparsity'], num=hyps['val_num_sparsities']):
                    elastic_dnn_util.set_master_dnn_sparsity(md_for_test, val_sparsity)
                    bn_stats[f'{val_sparsity:.4f}'] = self.bn_cal(md_for_test, train_loader, hyps['bn_cal_num_iters'], device)
                    self.models['md'].models_dict['main'] = md_for_test
                    self.models['md'].to_eval_mode()
                    val_acc = self.models['md'].get_accuracy(val_loader)
                    
                    val_accs[f'{val_sparsity:.4f}'] = val_acc
                    avg_val_acc += val_acc
                    
                avg_val_acc /= hyps['val_num_sparsities']
                
                self.models['md'].models_dict['main'] = cur_md
                self.models['md'].models_dict['indexes'] = indexes
                self.models['md'].models_dict['bn_stats'] = bn_stats
                self.models['fm'].models_dict['indexes'] = indexes
                
                self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt'))
                self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))
                
                if avg_val_acc > best_avg_val_acc:
                    best_avg_val_acc = avg_val_acc
                    self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt'))
                    self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))
                
            tb_writer.add_scalars('losses', dict(task=task_loss, distill=distill_loss, index=index_loss, total=total_loss), iter_index)
            pbar.set_description(f'loss: {total_loss:.6f}')
            if (iter_index + 1) >= hyps['val_freq']:
                # after the first validation, keep logging the most recent validation metrics
                tb_writer.add_scalars('accs/val_accs', val_accs, iter_index)
                tb_writer.add_scalar('accs/avg_val_acc', avg_val_acc, iter_index)
                pbar.set_description(f'loss: {total_loss:.6f}, avg_val_acc: {avg_val_acc:.4f}')