import torch, argparse
from commonsense_model import CommonsenseGRUModel
from dataloader import RobertaCometDataset
from torch.utils.data import DataLoader


def load_model(model_path, args):
    """Build a CommonsenseGRUModel, restore its saved weights and set eval mode.

    The layer sizes are fixed here; ``load_state_dict`` (strict by default)
    rejects a checkpoint whose parameter shapes do not match them.

    Args:
        model_path: path to a file holding the model's ``state_dict``.
        args: options namespace (see ``parse_cosmic_args``); this reads
            ``cuda``, ``active_listener``, ``attention``, ``rec_dropout``,
            ``dropout``, ``mode1``, ``norm`` and ``residual``.

    Returns:
        The model in eval mode, moved to the GPU when ``args.cuda`` is true.
    """
    D_m, D_s = 1024, 768
    D_g = D_p = D_r = D_i = 150
    D_h = D_a = 100
    # D_e is derived from the three sizes above rather than set directly.
    D_e = D_p + D_r + D_i
    positional_dims = (D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, D_a)

    model = CommonsenseGRUModel(
        *positional_dims,
        # 15 output classes; this script's label mapping has 15 entries.
        n_classes=15,
        emo_gru=True,
        listener_state=args.active_listener,
        context_attention=args.attention,
        dropout_rec=args.rec_dropout,
        dropout=args.dropout,
        mode1=args.mode1,
        norm=args.norm,
        residual=args.residual,
    )

    if not args.cuda:
        # Remap stored tensors onto the CPU so the checkpoint loads without CUDA.
        state = torch.load(model_path, map_location=torch.device("cpu"))
    else:
        model.cuda()
        state = torch.load(model_path)
    model.load_state_dict(state)

    model.eval()
    return model


def get_valid_dataloader(
    roberta_features_path: str,
    comet_features_path: str,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
):
    """Build a DataLoader over the "valid" split of ``RobertaCometDataset``.

    Args:
        roberta_features_path: RoBERTa feature file handed to the dataset.
        comet_features_path: COMET feature file handed to the dataset.
        batch_size: number of dataset items per batch.
        num_workers: DataLoader worker processes (0 loads in the main process).
        pin_memory: forwarded unchanged to ``DataLoader``.

    Returns:
        ``(loader, keys)`` where ``keys`` is the dataset's ``keys`` attribute.
        No ``shuffle`` is passed, so batches follow dataset index order
        (presumably the same order as ``keys`` -- confirm in the dataset).
    """
    dataset = RobertaCometDataset("valid", roberta_features_path, comet_features_path)
    loader_options = dict(
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return DataLoader(dataset, **loader_options), dataset.keys


def predict(model, data_loader, args):
    """Run ``model`` over ``data_loader`` and return argmax class indices.

    Args:
        model: a model such as the one returned by ``load_model``. It must
            return 6 outputs; the first is treated as per-class scores with
            the class axis last (axis 2).
        data_loader: iterable of batches. The last element of each batch is
            ignored and the remaining 16 are unpacked as model inputs, masks
            and labels (NOTE(review): assumed to match the layout produced by
            RobertaCometDataset.collate_fn -- confirm).
        args: namespace whose boolean ``cuda`` selects moving inputs to GPU.

    Returns:
        A list with one numpy array of predicted class indices per batch.
        The model output has its first two axes swapped and is flattened to
        ``(-1, n_classes)``. Assuming it is (seq_len, batch, n_classes) --
        TODO confirm -- a batch size of 1 yields one prediction per utterance
        of that conversation.
    """
    predictions = []
    # Pure inference: without no_grad every batch would build an autograd
    # graph that is never used for backpropagation.
    with torch.no_grad():
        for batch in data_loader:
            tensors = batch[:-1]
            if args.cuda:
                tensors = [t.cuda() for t in tensors]
            (
                r1, r2, r3, r4,
                x1, x2, x3, x4, x5, x6,
                o1, o2, o3,
                qmask, umask, label,
            ) = tensors

            # NOTE(review): argument order (x5, x6, x1, o2, o3) is kept exactly
            # as in the original call; confirm against the model's forward().
            log_prob, _, _, _, _, _ = model(
                r1, r2, r3, r4, x5, x6, x1, o2, o3, qmask, umask
            )

            flat = log_prob.transpose(0, 1).reshape(-1, log_prob.size(2))
            predictions.append(flat.argmax(dim=-1).cpu().numpy())

    return predictions


def parse_cosmic_args(argv=None):
    """Parse the command-line options used to build and run the COSMIC model.

    Args:
        argv: list of argument strings to parse; ``None`` (the default)
            parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with every option below plus a derived boolean
        ``cuda``: True only when CUDA is available and GPU use was requested.
    """
    parser = argparse.ArgumentParser()

    # Device selection. --no-cuda is a store_true flag whose default is
    # already True, so on its own it could never enable the GPU. CPU stays
    # the default; --cuda opts in by clearing the same ``no_cuda`` dest.
    parser.add_argument(
        "--no-cuda", action="store_true", default=True, help="does not use GPU"
    )
    parser.add_argument(
        "--cuda",
        dest="no_cuda",
        action="store_false",
        help="use the GPU when one is available",
    )

    # lr, l2, batch-size, epochs, class-weight, tensorboard, seed and mu are
    # not read by the inference code in this file; they are kept so existing
    # command lines still parse.
    parser.add_argument(
        "--lr", type=float, default=0.0001, metavar="LR", help="learning rate"
    )
    parser.add_argument(
        "--l2",
        type=float,
        default=0.00003,
        metavar="L2",
        help="L2 regularization weight",
    )
    parser.add_argument(
        "--rec-dropout",
        type=float,
        default=0.3,
        metavar="rec_dropout",
        help="rec_dropout rate",
    )
    parser.add_argument(
        "--dropout", type=float, default=0.5, metavar="dropout", help="dropout rate"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, metavar="BS", help="batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, metavar="E", help="number of epochs"
    )

    # The boolean flags below default to True; each gets a --no-... twin
    # writing to the same dest so it can actually be turned off.
    parser.add_argument(
        "--class-weight", action="store_true", default=True, help="use class weights"
    )
    parser.add_argument(
        "--no-class-weight",
        dest="class_weight",
        action="store_false",
        help="do not use class weights",
    )
    parser.add_argument(
        "--active-listener", action="store_true", default=True, help="active listener"
    )
    parser.add_argument(
        "--no-active-listener",
        dest="active_listener",
        action="store_false",
        help="disable the active listener state",
    )
    parser.add_argument(
        "--attention", default="simple", help="Attention type in context GRU"
    )
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        default=False,
        help="Enables tensorboard log",
    )
    parser.add_argument("--mode1", type=int, default=2, help="Roberta features to use")
    parser.add_argument("--seed", type=int, default=500, metavar="seed", help="seed")
    parser.add_argument("--norm", type=int, default=0, help="normalization strategy")
    parser.add_argument("--mu", type=float, default=0, help="class_weight_mu")
    parser.add_argument(
        "--residual", action="store_true", default=True, help="use residual connection"
    )
    parser.add_argument(
        "--no-residual",
        dest="residual",
        action="store_false",
        help="do not use the residual connection",
    )

    # Inputs for the prediction script.
    parser.add_argument(
        "--model-path",
        default="epik/best_model.pt",
        help="path to the saved model state_dict",
    )
    parser.add_argument(
        "--roberta-features-path",
        default=None,
        help="RoBERTa features file for the validation dataset",
    )
    parser.add_argument(
        "--comet-features-path",
        default=None,
        help="COMET features file for the validation dataset",
    )

    args = parser.parse_args(argv)

    args.cuda = torch.cuda.is_available() and not args.no_cuda
    print("Running on GPU" if args.cuda else "Running on CPU")

    return args


# Class index -> label name for the model's 15 output classes.
LABEL_MAPPING = {
    0: "Curiosity",
    1: "Obscene",
    2: "Informative",
    3: "Openness",
    4: "Acceptance",
    5: "Interest",
    6: "Greeting",
    7: "Disapproval",
    8: "Denial",
    9: "Anxious",
    10: "Uninterested",
    11: "Remorse",
    12: "Confused",
    13: "Accusatory",
    14: "Annoyed",
}


def pred_to_labels(preds, mapping=None):
    """Map per-conversation class indices to label names.

    Args:
        preds: iterable with one sequence of class indices per conversation
            (e.g. the list of numpy arrays returned by ``predict``).
        mapping: class index -> label name; defaults to ``LABEL_MAPPING``.

    Returns:
        A list holding one list of label names per conversation. Every
        conversation is mapped (previously the function returned after the
        first one).

    Raises:
        KeyError: if an index has no entry in ``mapping``.
    """
    if mapping is None:
        mapping = LABEL_MAPPING
    # int() so numpy integer scalars look up the plain-int dictionary keys.
    return [[mapping[int(label)] for label in pred] for pred in preds]


if __name__ == "__main__":
    args = parse_cosmic_args()

    # getattr so a namespace without these options produces the clear
    # error below instead of an AttributeError.
    model_path = getattr(args, "model_path", None) or "epik/best_model.pt"
    roberta_features_path = getattr(args, "roberta_features_path", None)
    comet_features_path = getattr(args, "comet_features_path", None)
    if not roberta_features_path or not comet_features_path:
        raise SystemExit(
            "Both the RoBERTa and the COMET feature paths are required "
            "to build the validation dataloader."
        )

    model = load_model(model_path, args)
    # batch_size stays at get_valid_dataloader's default of 1: predict()
    # flattens each batch, so one batch must be one conversation for the
    # zip with ids below to line up.
    test_dataloader, ids = get_valid_dataloader(
        roberta_features_path, comet_features_path
    )
    predicted_labels = pred_to_labels(predict(model, test_dataloader, args))

    for conversation_id, labels in zip(ids, predicted_labels):
        print(f"Conversation ID: {conversation_id}")
        print(f"Predicted Sentiment Labels: {labels}")
        print(len(labels))