import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn

from functools import reduce
import operator
from bert.modeling_bert import BertModel

import torchvision
from lib import segmentation

import transforms as T
import utils
import numpy as np

import torch.nn.functional as F

import gc
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter


def get_dataset(image_set, transform, args):
    """Build the referring-segmentation dataset for one split.

    Args:
        image_set: split name forwarded to ``ReferDataset``; ``'val'`` switches
            the dataset into evaluation mode.
        transform: joint image/target transform applied by the dataset.
        args: parsed command-line options forwarded to ``ReferDataset``.

    Returns:
        ``(dataset, num_classes)`` where ``num_classes`` is always 2
        (background / foreground).
    """
    from data.dataset_refer_bert_rev import ReferDataset

    in_eval_mode = image_set == 'val'
    dataset = ReferDataset(
        args,
        split=image_set,
        image_transforms=transform,
        target_transforms=None,
        eval_mode=in_eval_mode,
    )
    n_classes = 2
    return dataset, n_classes


# IoU calculation for validation
def IoU(pred, gt):
    """Intersection-over-union between a logit map and a binary ground-truth mask.

    ``pred`` is reduced with ``argmax`` over dim 1, so it is expected to hold one
    channel per class (channel 1 = foreground); ``gt`` holds 0/1 labels.

    Returns:
        ``(iou, intersection, union)``; ``iou`` is a Python float, or the int 0
        when either the intersection or the union is empty. ``intersection``
        and ``union`` are the raw tensor counts.
    """
    labels = pred.argmax(1)

    overlap = (labels * gt).sum()
    combined = (labels + gt).sum() - overlap

    if overlap != 0 and combined != 0:
        return float(overlap) / float(combined), overlap, combined
    return 0, overlap, combined


def get_transform(args):
    """Compose the image/target preprocessing pipeline.

    Resizes to ``args.img_size`` x ``args.img_size``, converts to tensors, and
    normalizes with the ImageNet mean/std constants.
    """
    side = args.img_size
    steps = [
        T.Resize(side, side),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(steps)


def criterion(input, target):
    """Class-weighted cross-entropy for 2-class segmentation logits.

    Args:
        input: logits with the class dimension at dim 1 (two classes).
        target: integer class labels matching ``input`` without the class dim.

    Returns:
        Scalar loss; class 0 is weighted 0.9 and class 1 is weighted 1.1.
    """
    # Create the weight on the logits' own device and dtype. ``.cuda()`` used the
    # *current* CUDA device (which need not be the logits' device), failed on
    # CPU, and a fixed float32 weight broke non-float32 logits.
    weight = torch.tensor([0.9, 1.1], dtype=input.dtype, device=input.device)
    return nn.functional.cross_entropy(input, target, weight=weight)

# def return_mask(metric_mask, B_):
#     negative_mask = None

#     sim_mask = torch.zeros(B_, B_, device=metric_mask.device)
#     n_pos = B_//2

#     sim_mask.fill_diagonal_(1)  # Set diagonal elements to 1 for all cases
#     sim_mask.diagonal(offset=n_pos).fill_(1)
#     sim_mask.diagonal(offset=-n_pos).fill_(1)
        

#     return sim_mask, negative_mask


def return_mask(emb_distance, verb_mask=None):
    """Build positive/negative pair masks for the angular contrastive losses.

    Every row is always a positive of itself (the diagonal). Extra positives:

    * ``verb_mask`` is ``None`` or has more entries than the matrix has rows:
      the rows were already restricted to verb samples, so consecutive rows
      ``(0, 1), (2, 3), ...`` are paired (a trailing odd row stays unpaired).
    * otherwise ``verb_mask`` covers every row. Scanning left to right, a row
      flagged 1/True is paired with the following row and both are consumed.
      Unflagged rows are skipped.

    Args:
        emb_distance: square (B_, B_) matrix; only its shape, dtype and device
            are used.
        verb_mask: optional per-sample flags (list or tensor).

    Returns:
        ``(positive_mask, negative_mask)``, both shaped like ``emb_distance``
        with 0/1 entries, where ``negative_mask == 1 - positive_mask``.

    Raises:
        ValueError: if ``emb_distance`` is not a square 2-D matrix.
    """
    if emb_distance.dim() != 2 or emb_distance.shape[0] != emb_distance.shape[1]:
        raise ValueError(
            f"emb_distance must be a square 2-D matrix, got shape {tuple(emb_distance.shape)}"
        )
    n = emb_distance.shape[0]

    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # each sample is its own positive

    if verb_mask is None or n < len(verb_mask):
        # Rows are already the verb subset, laid out as consecutive pairs.
        for i in range(0, n - 1, 2):
            positive_mask[i, i + 1] = 1
            positive_mask[i + 1, i] = 1
    else:
        # Mixed batch: a flagged row and its successor form a pair.
        i = 0
        while i < n:
            # The bounds check stops a flagged last row from indexing past the matrix.
            if verb_mask[i] == 1 and i + 1 < n:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

    negative_mask = 1 - positive_mask
    return positive_mask, negative_mask


def UniAngularContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Additive-angular-margin contrastive loss on spatially pooled features.

    Each feature map is averaged over its two spatial dims into a C-dim
    embedding, and pairwise cosine similarities are computed. For every
    positive pair (from ``return_mask``; the diagonal is always positive) the
    similarity cos(theta) is replaced by cos(theta + m). Each row's loss is
    ``-log(sum_pos exp(s / tau) / sum_all exp(s / tau))``, and rows are averaged.

    Args:
        total_fq: feature tensor of shape (B, C, H, W).
        verb_mask: boolean mask over the B samples. With ``verbonly`` only the
            selected rows are used, and their count must be even
            (consecutive anchor/positive pairs).
        alpha: unused; kept for call compatibility.
        verbonly: restrict the loss to rows selected by ``verb_mask``.
        m: angular margin in degrees.
        tau: softmax temperature.
        args: unused; kept for call compatibility.

    Returns:
        Scalar loss tensor, or ``0.0`` when no embeddings are selected.
    """
    _, C, H, W = total_fq.shape

    if verbonly:
        selected = total_fq[verb_mask]
        emb = torch.mean(selected, dim=(-1, -2)).reshape(selected.shape[0], C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool over BOTH spatial dims. The former ``dim=-1`` left a (B, C, H)
        # tensor, which made the pairwise similarity below fail.
        emb = torch.mean(total_fq, dim=(-1, -2))

    B_ = emb.shape[0]
    if B_ == 0:
        return torch.tensor(0.0, device=total_fq.device)

    # Broadcasting (B_, 1, C) against (1, B_, C) yields the (B_, B_) matrix
    # without materialising two repeated (B_, B_, C) copies.
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb.unsqueeze(1), emb.unsqueeze(0))
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)  # keeps acos finite

    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    pos = positive_mask.bool()

    sim_matrix_with_margin = sim_matrix.clone()
    # m is in degrees; 57.2958 ~= 180 / pi converts it to radians.
    sim_matrix_with_margin[pos] = torch.cos(torch.acos(sim_matrix[pos]) + m / 57.2958)
    logits = sim_matrix_with_margin / tau

    # -log(sum_pos exp / sum_all exp) computed in log space: exp(logits) overflows
    # float32 once 1/tau exceeds ~88. The diagonal is always in the positive
    # set, so log_pos is finite for every row.
    log_total = torch.logsumexp(logits, dim=-1)
    log_pos = torch.logsumexp(logits.masked_fill(~pos, float('-inf')), dim=-1)
    return (log_total - log_pos).mean()



def UniAngularLogitContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular contrastive loss whose logits are the angle-based score pi/2 - theta.

    Each feature map is averaged over its two spatial dims into a C-dim
    embedding. theta is the angle between embeddings (from their clamped cosine
    similarity). For positive pairs (from ``return_mask``; the diagonal is
    always positive) the logit is reduced by the margin. Each row's loss is
    ``-log(sum_pos exp(l / tau) / sum_all exp(l / tau))``, and rows are averaged.

    Args:
        total_fq: feature tensor of shape (B, C, H, W).
        verb_mask: boolean mask over the B samples. With ``verbonly`` only the
            selected rows are used, and their count must be even.
        alpha: unused; kept for call compatibility.
        verbonly: restrict the loss to rows selected by ``verb_mask``.
        m: angular margin in degrees.
        tau: softmax temperature.
        args: unused; kept for call compatibility.

    Returns:
        Scalar loss tensor, or ``0.0`` when no embeddings are selected.
    """
    _, C, H, W = total_fq.shape

    if verbonly:
        selected = total_fq[verb_mask]
        emb = torch.mean(selected, dim=(-1, -2)).reshape(selected.shape[0], C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool over BOTH spatial dims. The former ``dim=-1`` left a (B, C, H)
        # tensor, which made the pairwise similarity below fail.
        emb = torch.mean(total_fq, dim=(-1, -2))

    B_ = emb.shape[0]
    if B_ == 0:
        # The mean over zero rows below would be NaN; report a zero loss instead.
        return torch.tensor(0.0, device=total_fq.device)

    # Broadcast instead of materialising two repeated (B_, B_, C) copies.
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb.unsqueeze(1), emb.unsqueeze(0))
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)  # keeps acos finite

    margin_in_radians = m / 57.2958  # degrees -> radians (57.2958 ~= 180 / pi)
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)  # in (-pi/2, pi/2)
    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    pos = positive_mask.bool()

    theta_with_margin = theta_matrix.clone()
    theta_with_margin[pos] -= margin_in_radians  # harder target for positives
    logits = theta_with_margin / tau

    # Log-space softmax ratio: exp(logits) overflows float32 once (pi/2)/tau
    # exceeds ~88. The diagonal is always positive, so log_pos is finite.
    log_total = torch.logsumexp(logits, dim=-1)
    log_pos = torch.logsumexp(logits.masked_fill(~pos, float('-inf')), dim=-1)
    return (log_total - log_pos).mean()

# def UniAngularContrastLoss(samples_with_pos, metric_mask, m=0.5, tau=0.05, verb_mask=None, verbonly=True, args=None):
#     B_, C, H, W = samples_with_pos.shape
    
#     emb = torch.mean(samples_with_pos, dim=(-1, -2)).reshape(B_, C)
#     if len(emb) > 0: 
#         sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
#         sim_matrix = sim(emb.unsqueeze(1), emb.unsqueeze(0))  # (B_, B_)
#         sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
        
#         sim_mask, negative_mask = return_mask(emb, B_)
#         assert sim_mask.shape == sim_matrix.shape, f"sim_mask shape {sim_mask.shape} is not equal to sim_matrix shape {sim_matrix.shape}."

#         # Apply margin to positive pairs
#         sim_matrix_with_margin = sim_matrix.clone()
#         sim_matrix_with_margin[sim_mask.bool()] = torch.cos(torch.acos(sim_matrix[sim_mask.bool()]) + m / 57.2958)        

#         # Scale logits with temperature
#         logits = sim_matrix_with_margin / tau

#         # Compute the softmax loss for all pairs
#         exp_logits = torch.exp(logits)
#         # print("exp_logits: ", exp_logits.shape)
#         pos_exp_logits = exp_logits * sim_mask.long()
#         pos_exp_logits = pos_exp_logits.sum(dim=-1)
#         #print("pos_exp_logits: ", pos_exp_logits.shape)
#         # print("pos_exp_logits: ", pos_exp_logits.shape)
#         total_exp_logits = exp_logits.sum(dim=-1)
#         positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
#         angular_loss = positive_loss.mean()

#         return angular_loss
#     else :
#         return torch.tensor(0.0, device=samples_with_pos.device)


def evaluate(model, data_loader, bert_model):
    """Evaluate ``model`` on ``data_loader``; print and return IoU/precision metrics.

    Args:
        model: segmentation model, called with ``is_train=False``. Each output
            row is passed to ``IoU``, which takes ``argmax(1)`` of it.
        data_loader: yields ``(image, target, sentences, attentions)`` batches.
        bert_model: external text encoder, or ``None`` when ``model`` consumes
            the token ids directly.

    Returns:
        ``(100 * mean_iou, 100 * cum_I / cum_U, precs)``: the mean per-sample
        IoU, the overall (cumulative) IoU, and the precision in percent at the
        IoU thresholds 0.5, 0.6, 0.7, 0.8, 0.9. The overall IoU is a tensor,
        because ``IoU`` returns tensor intersection/union counts.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.
    """
    #print("current model : ", model)
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    total_its = 0
    acc_ious = 0

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            total_its += 1
            image, target, sentences, attentions = data
            image, target, sentences, attentions = \
                                                    image.cuda(non_blocking=True),\
                                                    target.cuda(non_blocking=True),\
                                                    sentences.cuda(non_blocking=True),\
                                                    attentions.cuda(non_blocking=True)

            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)

            if bert_model is not None:
                # External text encoder: feed its last hidden states to the model.
                last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
                embedding = last_hidden_states.permute(0, 2, 1)  # (B, 768, N_l) to make Conv1d happy
                attentions = attentions.unsqueeze(dim=-1)  # (B, N_l, 1)
                output = model(image, embedding, l_mask=attentions, is_train=False)
            else:
                # Model encodes text itself. After squeeze/transpose, dim 0 of
                # ``sentences`` indexes sentences (presumably all referring
                # expressions of one image — TODO confirm layout). Image and
                # target are repeated so each sentence gets its own row.
                sentences = sentences.squeeze(0).transpose(0, 1)
                attentions = attentions.squeeze(0).transpose(0, 1)
                image = torch.repeat_interleave(image, sentences.shape[0], dim=0)
                target = torch.repeat_interleave(target, sentences.shape[0], dim=0)
                output = model(image, sentences, l_mask=attentions, is_train=False)

            # Score every output row separately against its target.
            for i in range(output.shape[0]):
                iou, I, U = IoU(output[i].unsqueeze(0), target[i])
                acc_ious += iou
                mean_IoU.append(iou)
                cum_I += I
                cum_U += U
                for n_eval_iou in range(len(eval_seg_iou_list)):
                    eval_seg_iou = eval_seg_iou_list[n_eval_iou]
                    seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
                seg_total += 1
        # Reuses the name ``iou``: from here on it is the mean per-sample IoU.
        iou = acc_ious / seg_total

    # mIoU is the same mean as ``iou`` above, computed via numpy for printing.
    mean_IoU = np.array(mean_IoU)
    mIoU = np.mean(mean_IoU)
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''

    precs = []
    for n_eval_iou in range(len(eval_seg_iou_list)):
        results_str += '    precision@%s = %.2f\n' % \
                       (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
        precs.append(seg_correct[n_eval_iou] * 100. / seg_total)
    results_str += '    overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
    print(results_str)

    return 100 * iou, 100 * cum_I / cum_U, precs


def _interleave_hard_positives(image, target, sentences, attentions, pos_sent, pos_attn_mask):
    """Expand a batch by inserting each sample's hard-positive sentence right after it.

    A sample has a hard positive when its ``pos_attn_mask`` row is not all
    zeros. The inserted row reuses the anchor's image and target.

    Returns:
        ``(image, target, sentences, attentions, verb_masks, cl_masks, pos_mask)``:
        - the four expanded inputs, re-stacked;
        - ``verb_masks`` (CPU bool), marking both rows of every (anchor, positive) pair;
        - ``cl_masks`` (CPU bool), marking the rows that feed the segmentation loss
          (every anchor, never a positive);
        - ``pos_mask``, the per-original-sample "has a positive" flag.
    """
    pos_mask = pos_attn_mask.sum(dim=-1) > 0

    images, targets, sents, attns = [], [], [], []
    verb_masks, cl_masks = [], []
    for idx, has_pos in enumerate(pos_mask.tolist()):
        images.append(image[idx])
        targets.append(target[idx])
        sents.append(sentences[idx])
        attns.append(attentions[idx])
        if has_pos:
            images.append(image[idx])
            targets.append(target[idx])
            sents.append(pos_sent[idx])
            attns.append(pos_attn_mask[idx])
            verb_masks += [True, True]
            cl_masks += [True, False]
        else:
            verb_masks.append(False)
            cl_masks.append(True)

    return (torch.stack(images), torch.stack(targets), torch.stack(sents), torch.stack(attns),
            torch.tensor(verb_masks, dtype=torch.bool), torch.tensor(cl_masks, dtype=torch.bool),
            pos_mask)


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, bert_model, metric_learning=False, args=None):
    """Run one training epoch.

    Args:
        model: segmentation model.
            - With ``bert_model``, it is called as ``model(image, embedding, l_mask=...)``
              and its result is used directly as logits.
            - Otherwise it is called as ``model(image, sentences, l_mask=...)`` and
              unpacked as ``(logits, metric_tensors)``.
        criterion: ``callable(logits, target) -> loss``.
        optimizer, lr_scheduler: each stepped once per iteration.
        epoch: used only in the log header.
        print_freq: logging interval passed to the metric logger.
        iterations: running iteration counter; incremented once per batch.
        bert_model: external text encoder, or ``None``.
        metric_learning: add the angular contrastive term when hard positives exist.
        args: reads ``metric_loss_weight`` and ``metric_mode``. When the metric
            loss is used it also reads ``margin_value`` and ``temperature``.

    Returns:
        ``(iterations, {'loss': <epoch-average loss from the metric logger>})``.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode
    if not metric_learning:
        mlw = 0

    for data in metric_logger.log_every(data_loader, print_freq, header):
        if 'hardpos_only' in metric_mode:
            image, target, sentences, attentions, pos_sent, pos_attn_mask = data
        else:
            # Hard negatives are also yielded in this mode but this loop does not use them.
            image, target, sentences, attentions, pos_sent, pos_attn_mask, _neg_sent, _neg_attn_mask = data

        image = image.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        sentences = sentences.cuda(non_blocking=True).squeeze(1)
        attentions = attentions.cuda(non_blocking=True).squeeze(1)
        pos_sent = pos_sent.cuda(non_blocking=True).squeeze(1)
        pos_attn_mask = pos_attn_mask.cuda(non_blocking=True).squeeze(1)

        metric_loss = 0
        if bert_model is not None:
            last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
            embedding = last_hidden_states.permute(0, 2, 1)  # (B, 768, N_l) to make Conv1d happy
            attentions = attentions.unsqueeze(dim=-1)  # (batch, N_l, 1)
            # Assumes the model returns logits only on this path, as the
            # original call did. ``ce_loss`` was never assigned here before,
            # so computing ``loss`` below raised NameError.
            output = model(image, embedding, l_mask=attentions)
            ce_loss = criterion(output, target)
        else:
            image, target, sentences, attentions, verb_masks, cl_masks, pos_mask = \
                _interleave_hard_positives(image, target, sentences, attentions, pos_sent, pos_attn_mask)
            output, metric_tensors = model(image, sentences, l_mask=attentions)

            # Only anchor rows feed the segmentation loss; positives duplicate their image/target.
            ce_loss = criterion(output[cl_masks], target[cl_masks])

            if metric_learning and pos_mask.any():
                metric_loss = UniAngularLogitContrastLoss(metric_tensors, verb_masks, m=args.margin_value,
                                                          tau=args.temperature, verbonly=True, args=args)

        # Weighted average of the two terms; with mlw == 0 this is just ce_loss.
        loss = (ce_loss + metric_loss * mlw) / (1 + mlw)
        optimizer.zero_grad()  # set_to_none=True is only available in pytorch 1.6+
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        torch.cuda.synchronize()
        iterations += 1
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

        # Drop per-iteration references before releasing cached CUDA memory.
        del image, target, sentences, attentions, loss, output, data
        if bert_model is not None:
            del last_hidden_states, embedding
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    loss_log = {'loss': metric_logger.meters['loss'].global_avg}
    return iterations, loss_log



def main(args):
    """Build data, model and optimizer, then train with validation after every epoch.

    The checkpoint with the best overall IoU on the 'val' split is saved to
    ``args.output_dir/model_best_{args.model_id}.pth``. Validation metrics and
    the mean training loss are logged to TensorBoard under
    ``./experiments/{dataset}_{splitBy}/{model_id}``.
    """
    # Scalars are only written by the main process, so only it opens an event
    # file. Other ranks used to create empty event files in the same run dir.
    writer = None
    if utils.is_main_process():
        writer = SummaryWriter('./experiments/{}/{}'.format("_".join([args.dataset, args.splitBy]), args.model_id))

    dataset, num_classes = get_dataset("train",
                                       get_transform(args=args),
                                       args=args)
    dataset_test, _ = get_dataset("val",
                                  get_transform(args=args),
                                  args=args)

    # batch sampler
    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
                                                                    shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # data loader
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=args.pin_mem, drop_last=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers)

    # model initialization
    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights,
                                              args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    single_model = model.module

    if args.model != 'lavt_one':
        bert_model = BertModel.from_pretrained(args.ck_bert)
        bert_model.pooler = None  # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel
        bert_model.cuda()
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
        bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank])
        single_bert_model = bert_model.module
    else:
        # 'lavt_one' carries its own text encoder (single_model.text_encoder below).
        bert_model = None
        single_bert_model = None

    # resume training
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
        if args.model != 'lavt_one':
            single_bert_model.load_state_dict(checkpoint['bert_model'])

    # parameters to optimize: no weight decay for norms and position tables
    backbone_no_decay = list()
    backbone_decay = list()
    for name, m in single_model.backbone.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(m)
        else:
            backbone_decay.append(m)

    if args.model != 'lavt_one':
        params_to_optimize = [
            {'params': backbone_no_decay, 'weight_decay': 0.0},
            {'params': backbone_decay},
            {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
            # the following are the parameters of bert (first 10 encoder layers)
            {"params": reduce(operator.concat,
                              [[p for p in single_bert_model.encoder.layer[i].parameters()
                                if p.requires_grad] for i in range(10)])},
        ]
    else:
        params_to_optimize = [
            {'params': backbone_no_decay, 'weight_decay': 0.0},
            {'params': backbone_decay},
            {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
            # the following are the parameters of the built-in text encoder (first 10 layers, lr / 10)
            {"params": reduce(operator.concat,
                              [[p for p in single_model.text_encoder.encoder.layer[i].parameters()
                                if p.requires_grad] for i in range(10)]), 'lr': args.lr/10},
        ]

    # optimizer
    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad
                                  )

    # polynomial (power 0.9) decay over all training iterations
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    # housekeeping
    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    # resume training (optimizer, lr scheduler, and the epoch)
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = -999

    # training loops
    for epoch in range(max(0, resume_epoch+1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        # Keep the returned counter; it was previously discarded, so every epoch restarted at 0.
        iterations, loss_log = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch,
                                               args.print_freq, iterations, bert_model,
                                               metric_learning=args.metric_learning, args=args)
        iou, overallIoU, precs = evaluate(model, data_loader_test, bert_model)

        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))
        save_checkpoint = (best_oIoU < overallIoU)
        if save_checkpoint:
            print('Better epoch: {}\n'.format(epoch))
            if single_bert_model is not None:
                dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(),
                                'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                                'lr_scheduler': lr_scheduler.state_dict()}
            else:
                dict_to_save = {'model': single_model.state_dict(),
                                'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                                'lr_scheduler': lr_scheduler.state_dict()}

            utils.save_on_master(dict_to_save, os.path.join(args.output_dir,
                                                            'model_best_{}.pth'.format(args.model_id)))
            best_oIoU = overallIoU

        if writer is not None:
            writer.add_scalar('val/mIoU', iou, epoch)
            writer.add_scalar('val/oIoU', overallIoU, epoch)
            writer.add_scalar('val/Prec/50', precs[0], epoch)
            writer.add_scalar('val/Prec/60', precs[1], epoch)
            writer.add_scalar('val/Prec/70', precs[2], epoch)
            writer.add_scalar('val/Prec/80', precs[3], epoch)
            writer.add_scalar('val/Prec/90', precs[4], epoch)
            writer.add_scalar('train/loss', loss_log['loss'], epoch)

    if writer is not None:
        writer.close()  # flushes pending events and releases the event file

    # summarize
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    from args import get_parser

    parser = get_parser()
    args = parser.parse_args()

    # Distributed setup must precede main(), which queries ranks and builds DDP models.
    utils.init_distributed_mode(args)
    print(f'Image size: {args.img_size}')
    main(args)