import torch
from torch import nn
from abc import ABC, abstractmethod

from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger


class KTakesAll(nn.Module):
    """Zero out the smallest entries of a tensor along its last dimension.

    ``k`` is the sparsity ratio: the larger ``k`` is, the more entries are
    zeroed (and the smaller the resulting model is).
    """

    def __init__(self, k):
        super().__init__()
        # Fraction of the last dimension to suppress on each forward pass.
        self.k = k
        # Indices zeroed by the most recent forward pass (None until then).
        self.cached_i = None

    def forward(self, g: torch.Tensor):
        """Return a copy of ``g`` with its smallest entries set to 0.

        ``int(g.size(-1) * self.k)`` entries are zeroed per slice along the
        last dimension; their indices are stored in ``self.cached_i``.
        """
        num_suppressed = int(g.size(-1) * self.k)
        # Top-k of the negated tensor == the k smallest entries of ``g``.
        _, suppressed_idx = torch.topk(g.neg(), num_suppressed, dim=-1)
        self.cached_i = suppressed_idx
        # Out-of-place scatter: ``g`` itself is left untouched.
        return g.scatter(-1, suppressed_idx, 0)
    
    
class Abs(nn.Module):
    """Module form of the element-wise absolute value."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the element-wise absolute value of ``x``."""
        return torch.abs(x)


class Layer_WrappedWithFBS(nn.Module):
    """Base module for a layer wrapped with FBS channel-attention state.

    Holds a ``KTakesAll`` sub-module plus the cached-attention attributes that
    ``ElasticDNNUtil`` reads and resets. The code that fills the caches is not
    part of this class (presumably done by subclasses -- TODO confirm).
    """

    def __init__(self):
        super().__init__()

        # Cached attention values; both stay None until something sets them.
        self.cached_raw_channel_attention = None
        self.cached_channel_attention = None
        # When True, the cached attention is meant to be reused.
        self.use_cached_channel_attention = False

        # Sparsity gate, created with an initial sparsity of 0.5.
        self.k_takes_all = KTakesAll(0.5)


class ElasticDNNUtil(ABC):
    """Abstract helper for working with an FBS-wrapped "master" DNN.

    Concrete subclasses supply the model-specific conversion, sample selection
    and surrogate extraction. This base class provides the model-agnostic
    operations, which act on every ``Layer_WrappedWithFBS`` / ``KTakesAll``
    module found inside the master DNN.
    """

    @abstractmethod
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=None):
        """Convert ``raw_dnn`` into a master DNN (implemented by subclasses).

        Args:
            raw_dnn: The model to convert.
            r: Passed through unchanged; its meaning is defined by the
                concrete implementation.
            ignore_layers: Layers to leave untouched. ``None`` means "none".
        """
        raise NotImplementedError

    def convert_raw_dnn_to_master_dnn_with_perf_test(self, raw_dnn: nn.Module, r: float, ignore_layers=None):
        """Convert ``raw_dnn`` into a master DNN and log the size change.

        Args:
            raw_dnn: The model to convert.
            r: Forwarded to ``convert_raw_dnn_to_master_dnn``.
            ignore_layers: Forwarded to ``convert_raw_dnn_to_master_dnn``.
                ``None`` (the default) is replaced by a fresh empty list so
                that no list object is shared between calls.

        Returns:
            The master DNN produced by ``convert_raw_dnn_to_master_dnn``.
        """
        if ignore_layers is None:
            ignore_layers = []

        # The size must be read before conversion, in case the conversion
        # modifies ``raw_dnn`` in place.
        raw_dnn_size = get_model_size(raw_dnn, True)
        master_dnn = self.convert_raw_dnn_to_master_dnn(raw_dnn, r, ignore_layers)
        master_dnn_size = get_model_size(master_dnn, True)

        logger.info(f'master DNN w/o FBS ({raw_dnn_size:.3f}MB) -> master DNN w/ FBS ({master_dnn_size:.3f}MB) '
                    f'(↑ {(((master_dnn_size - raw_dnn_size) / raw_dnn_size) * 100.):.2f}%)')
        return master_dnn

    def set_master_dnn_inference_via_cached_channel_attention(self, master_dnn: nn.Module):
        """Switch every FBS-wrapped layer to use its cached channel attention.

        Asserts that each wrapped layer already has a non-None
        ``cached_channel_attention``.
        """
        for module in master_dnn.modules():
            if isinstance(module, Layer_WrappedWithFBS):
                assert module.cached_channel_attention is not None
                module.use_cached_channel_attention = True

    def set_master_dnn_dynamic_inference(self, master_dnn: nn.Module):
        """Drop the cached channel attention and turn cached inference off."""
        for module in master_dnn.modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_channel_attention = None
                module.use_cached_channel_attention = False

    def train_only_fbs_of_master_dnn(self, master_dnn: nn.Module):
        """Make only the FBS parameters trainable.

        A parameter counts as an FBS parameter when its qualified name
        contains ``'.fbs'``. Those get ``requires_grad = True``; every other
        parameter gets ``requires_grad = False``.

        Returns:
            The list of FBS parameters, in ``named_parameters()`` order.
        """
        fbs_params = []
        for param_name, param in master_dnn.named_parameters():
            is_fbs_param = '.fbs' in param_name
            param.requires_grad = is_fbs_param
            if is_fbs_param:
                fbs_params.append(param)
        return fbs_params

    def get_accu_l1_reg_of_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        """Sum the L1 norms of all cached raw channel attentions.

        Every FBS-wrapped layer must already hold a non-None
        ``cached_raw_channel_attention``; a layer whose cache is still
        ``None`` raises ``AttributeError``.

        Returns:
            The accumulated norm, or the float ``0.`` when the model contains
            no FBS-wrapped layer.
        """
        res = 0.
        for module in master_dnn.modules():
            if isinstance(module, Layer_WrappedWithFBS):
                res += module.cached_raw_channel_attention.norm(1)
        return res

    def get_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        """Return ``{module name: cached_raw_channel_attention}`` for each FBS-wrapped layer."""
        return {
            name: module.cached_raw_channel_attention
            for name, module in master_dnn.named_modules()
            if isinstance(module, Layer_WrappedWithFBS)
        }

    def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float):
        """Set ``k`` of every ``KTakesAll`` module in the master DNN.

        Args:
            master_dnn: Model whose ``KTakesAll`` modules are updated.
            sparsity: New sparsity; asserted to lie in ``[0, 1]``.
        """
        assert 0 <= sparsity <= 1., sparsity
        for module in master_dnn.modules():
            if isinstance(module, KTakesAll):
                module.k = sparsity
        logger.debug(f'set master DNN sparsity to {sparsity}')

    def clear_cached_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        """Reset both cached attention attributes of every FBS-wrapped layer to None."""
        for module in master_dnn.modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_raw_channel_attention = None
                module.cached_channel_attention = None

    @abstractmethod
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        """Select the most representative sample (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):
        """Extract a surrogate DNN from ``master_dnn`` (implemented by subclasses).

        When ``return_detail`` is true, the result is expected to be a
        two-element sequence whose first element is the surrogate DNN.
        """
        raise NotImplementedError

    def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples: torch.Tensor,
                                                         return_detail=False):
        """Extract a surrogate DNN and log size/latency against the master DNN.

        Args:
            master_dnn: The master DNN to extract from.
            samples: Input samples; also defines the profiling input shape,
                which is ``samples``' shape with the first dimension set to 1.
            return_detail: Forwarded to ``extract_surrogate_dnn_via_samples``.

        Returns:
            Exactly what ``extract_surrogate_dnn_via_samples`` returned.
        """
        # Profile with a single-sample input of the same trailing shape.
        input_size = (1, *samples.size()[1:])

        # The remaining positional arguments (50, device, 50, False) are
        # interpreted by get_model_latency, which is defined elsewhere.
        master_dnn_size = get_model_size(master_dnn, True)
        master_dnn_latency = get_model_latency(master_dnn, input_size, 50,
                                               get_model_device(master_dnn), 50, False)

        res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail)
        if return_detail:
            # Only the model is needed here; the detail part is returned
            # to the caller untouched via ``res``.
            surrogate_dnn, _ = res
        else:
            surrogate_dnn = res

        surrogate_dnn_size = get_model_size(surrogate_dnn, True)
        surrogate_dnn_latency = get_model_latency(surrogate_dnn, input_size, 50,
                                                  get_model_device(surrogate_dnn), 50, False)

        logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> '
                    f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, '
                    f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)')

        return res