import logging
import math
from collections import OrderedDict

import mmcv
import numpy as np
import torch
from torchvision.utils import save_image

from models.archs.fcn_arch import FCNHead
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
from models.archs.unet_arch import ShapeUNet
from models.losses.accuracy import accuracy
from models.losses.cross_entropy_loss import CrossEntropyLoss

# Module logger. The 'base' name is presumably shared with / configured by
# the project's training entry point (not referenced in this file's visible
# code) — confirm.
logger = logging.getLogger('base')


class ParsingGenModel:
    """Parsing generation model.

    Predicts a segmentation (parsing) map from a DensePose input
    (``data['densepose']``) conditioned on attribute labels
    (``data['attr']``). Three networks are chained:

    * ``ShapeAttrEmbedding``: attributes -> embedding,
    * ``ShapeUNet``: pose + embedding -> features,
    * ``FCNHead``: features -> per-class logits.

    All three are optimized jointly by a single Adam optimizer against a
    ``CrossEntropyLoss`` on ``data['segm']``.

    Args:
        opt (dict): Configuration. Read here: ``is_train``, the
            ``embedder_*`` / ``encoder_*`` / ``fc_*`` network options, ``lr``
            and ``weight_decay``. ``update_learning_rate`` and
            ``load_network`` read further keys.
    """

    # Checkpoint keys, in the same order as ``_networks()``; shared by
    # ``save_network`` and ``load_network`` so they cannot drift apart.
    _CKPT_KEYS = ('embedder', 'encoder', 'decoder')

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        self.attr_embedder = ShapeAttrEmbedding(
            dim=opt['embedder_dim'],
            out_dim=opt['embedder_out_dim'],
            cls_num_list=opt['attr_class_num']).to(self.device)
        self.parsing_encoder = ShapeUNet(
            in_channels=opt['encoder_in_channels']).to(self.device)
        self.parsing_decoder = FCNHead(
            in_channels=opt['fc_in_channels'],
            in_index=opt['fc_in_index'],
            channels=opt['fc_channels'],
            num_convs=opt['fc_num_convs'],
            concat_input=opt['fc_concat_input'],
            dropout_ratio=opt['fc_dropout_ratio'],
            num_classes=opt['fc_num_classes'],
            align_corners=opt['fc_align_corners'],
        ).to(self.device)

        self.init_training_settings()

        # RGB color per label index, used by ``palette_result``.
        self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
                        [250, 235, 215], [255, 250, 205], [211, 211, 211],
                        [70, 130, 180], [127, 255, 212], [0, 100, 0],
                        [50, 205, 50], [255, 255, 0], [245, 222, 179],
                        [255, 140, 0], [255, 0, 0], [16, 78, 139],
                        [144, 238, 144], [50, 205, 174], [50, 155, 250],
                        [160, 140, 88], [213, 140, 88], [90, 140, 90],
                        [185, 210, 205], [130, 165, 180], [225, 141, 151]]

    def _networks(self):
        """Return the three sub-networks in a fixed order."""
        return (self.attr_embedder, self.parsing_encoder,
                self.parsing_decoder)

    def _set_train_mode(self, mode):
        """Switch every sub-network to train (True) or eval (False) mode."""
        for net in self._networks():
            # ``Module.eval()`` is ``Module.train(False)``.
            net.train(mode)

    def init_training_settings(self):
        """Create the optimizer, the log dict and the loss function.

        Every parameter with ``requires_grad`` from the three sub-networks
        is handed to one Adam optimizer using ``opt['lr']`` and
        ``opt['weight_decay']``.
        """
        optim_params = []
        for net in self._networks():
            optim_params.extend(
                param for param in net.parameters() if param.requires_grad)
        # set up optimizers
        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()
        self.entropy_loss = CrossEntropyLoss().to(self.device)

    def feed_data(self, data):
        """Move the ``densepose``, ``attr`` and ``segm`` entries of ``data``
        to ``self.device`` and store them for the next training step."""
        self.pose = data['densepose'].to(self.device)
        self.attr = data['attr'].to(self.device)
        self.segm = data['segm'].to(self.device)

    def optimize_parameters(self):
        """Run one forward/backward/optimizer step on the fed batch.

        The forward intermediates are kept on ``self`` and the loss is
        recorded under ``log_dict['loss_total']``.
        """
        self._set_train_mode(True)

        self.attr_embedding = self.attr_embedder(self.attr)
        self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
        self.seg_logits = self.parsing_decoder(self.pose_enc)

        loss = self.entropy_loss(self.seg_logits, self.segm)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Log a detached value so the log never holds a reference into the
        # autograd graph.
        self.log_dict['loss_total'] = loss.detach()

    def get_vis(self, save_path):
        """Save the fed ``pose`` and ``segm`` side by side to ``save_path``.

        The two tensors are concatenated along dim 3 (the width axis,
        assuming NCHW — TODO confirm), mapped with ``(x + 1) / 2`` (i.e.
        assumes inputs in [-1, 1]), clamped to [0, 1] and written with
        ``save_image``.
        """
        img_cat = torch.cat([
            self.pose,
            self.segm,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)

        img_cat = img_cat.clamp_(0, 1)

        save_image(img_cat, save_path, nrow=1, padding=4)

    def _save_visualization(self, pose, segm, seg_pred, img_name, save_dir):
        """Write [pose | predicted parsing | GT parsing] for the first
        sample of a batch to ``{save_dir}/{img_name[0]}``."""
        palette_label = self.palette_result(segm.cpu().numpy())
        palette_pred = self.palette_result(seg_pred.cpu().numpy())
        # Assumes pose values lie in [-1, 1]; map to [0, 255] and broadcast
        # to 3 channels (works for 1- or 3-channel pose).
        pose_img = (pose[0] + 1) / 2. * 255.
        pose_img = pose_img.expand(3, pose_img.size(1), pose_img.size(2))
        pose_numpy = pose_img.cpu().numpy().clip(0, 255).astype(
            np.uint8).transpose(1, 2, 0)
        concat_result = np.concatenate(
            (pose_numpy, palette_pred, palette_label), axis=1)
        mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')

    def inference(self, data_loader, save_dir):
        """Evaluate on ``data_loader`` and save one visualization per batch.

        Args:
            data_loader (Iterable[dict]): Yields dicts with ``densepose``,
                ``attr``, ``segm`` and ``img_name`` entries.
            save_dir (str): Directory the visualizations are written to.

        Returns:
            float: Sum of the per-batch ``accuracy`` values divided by the
            total number of samples seen.
        """
        self._set_train_mode(False)

        acc = 0
        num = 0
        try:
            for data in data_loader:
                pose = data['densepose'].to(self.device)
                attr = data['attr'].to(self.device)
                segm = data['segm'].to(self.device)
                img_name = data['img_name']

                num += pose.size(0)
                with torch.no_grad():
                    attr_embedding = self.attr_embedder(attr)
                    pose_enc = self.parsing_encoder(pose, attr_embedding)
                    seg_logits = self.parsing_decoder(pose_enc)
                seg_pred = seg_logits.argmax(dim=1)
                # NOTE(review): accuracy is accumulated once per batch but
                # divided by the sample count below. That is a per-sample
                # mean only if batch size is 1 or ``accuracy`` scales with
                # batch size — confirm against models.losses.accuracy.
                acc += accuracy(seg_logits, segm)
                self._save_visualization(pose, segm, seg_pred, img_name,
                                         save_dir)
        finally:
            # Restore training mode even if evaluation fails part way.
            self._set_train_mode(True)

        return (acc / num).item()

    def get_current_log(self):
        """Return the dict of values logged by the last training step."""
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Set the learning rate of every param group for ``epoch``.

        The policy is chosen by ``opt['lr_decay']``:

        * ``'step'``: ``lr * gamma ** (epoch // step)``.
        * ``'cos'``: cosine decay from ``lr`` to 0 over ``num_epochs``.
        * ``'linear'``: linear decay from ``lr`` to 0 over ``num_epochs``.
        * ``'linear2exp'``: linear decay while
          ``epoch <= turning_point``, then the current lr is multiplied by
          ``gamma`` on every call.
        * ``'schedule'``: the current lr is multiplied by ``gamma`` when
          ``epoch`` is in ``opt['schedule']``.

        Args:
            epoch (int): Current epoch.

        Returns:
            float: The learning rate that was set.

        Raises:
            ValueError: If ``opt['lr_decay']`` is not one of the above.
        """
        opt = self.opt
        mode = opt['lr_decay']
        # Current lr; the multiplicative policies update it in place.
        lr = self.optimizer.param_groups[0]['lr']

        if mode == 'step':
            lr = opt['lr'] * (opt['gamma']**(epoch // opt['step']))
        elif mode == 'cos':
            lr = opt['lr'] * (
                1 + math.cos(math.pi * epoch / opt['num_epochs'])) / 2
        elif mode == 'linear':
            lr = opt['lr'] * (1 - epoch / opt['num_epochs'])
        elif mode == 'linear2exp':
            if epoch < opt['turning_point'] + 1:
                # The lr reaches ~5% of the base lr at the turning point
                # (i.e. it has decayed by ~95%; 1 / 95% = 1.0526).
                # NOTE(review): int() truncation makes the denominator equal
                # turning_point when it is small (< ~19), driving the lr to
                # exactly 0 there, and the exp phase then stays at 0.
                lr = opt['lr'] * (
                    1 - epoch / int(opt['turning_point'] * 1.0526))
            else:
                lr *= opt['gamma']
        elif mode == 'schedule':
            if epoch in opt['schedule']:
                lr *= opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(opt['lr_decay']))

        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, save_path):
        """Save the three sub-networks' state dicts to ``save_path``.

        Keys are ``'embedder'``, ``'encoder'`` and ``'decoder'``.
        """
        save_dict = {
            key: net.state_dict()
            for key, net in zip(self._CKPT_KEYS, self._networks())
        }
        torch.save(save_dict, save_path)

    def load_network(self):
        """Load all sub-networks from ``opt['pretrained_parsing_gen']``.

        Each state dict is loaded strictly and the network is then put in
        eval mode.
        """
        # Map to CPU so checkpoints saved on any GPU index can be loaded;
        # load_state_dict copies the values into the on-device parameters.
        checkpoint = torch.load(
            self.opt['pretrained_parsing_gen'], map_location='cpu')

        for key, net in zip(self._CKPT_KEYS, self._networks()):
            net.load_state_dict(checkpoint[key], strict=True)
            net.eval()

    def palette_result(self, result):
        """Colorize the first label map of ``result`` with ``self.palette``.

        Args:
            result (np.ndarray): Batch of integer label maps; only
                ``result[0]`` (indexed as a 2-D map) is used.

        Returns:
            np.ndarray: ``(H, W, 3)`` uint8 image in BGR channel order.
            Labels with no palette entry stay black.
        """
        seg = result[0]
        palette = np.array(self.palette)
        # Check rank before indexing shape[1].
        assert len(palette.shape) == 2
        assert palette.shape[1] == 3
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]
        return color_seg