import copy
import torch
from torch import nn
from transformers import AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
# torchcrf (optional) is imported lazily in load_model when args.use_crf is set.

class MyModel(nn.Module):
    def __init__(self, args, backbone):
        super().__init__()
        self.args = args
        self.backbone = backbone
        # CLS/BOS id prepended to continuation windows in generate;
        # 0 matches RoBERTa-style vocabularies.
        self.cls_id = 0
        hidden_dim = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, args.num_labels)
                )

        if args.distil_att:
            self.distil_att = nn.Parameter(torch.ones(self.backbone.config.hidden_size))

    def forward(self, x, mask):
        # Token-level logits over the full hidden sequence; attentions are
        # returned for attention distillation.
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        return out, self.classifier(out.last_hidden_state)

    def decisions(self, x, mask):
        # Same as forward but skips the attention outputs (cheaper at inference).
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=False)
        return out, self.classifier(out.last_hidden_state)

    def phenos(self, x, mask):
        # Sequence-level logits from the pooled [CLS] representation.
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        return out, self.classifier(out.pooler_output)

    def generate(self, x, mask, choice=None):
        outs = []
        if self.args.task == 'seq' or choice == 'seq':
            # Slide over the document in windows of max_len-1 tokens; each
            # window after the first gets a CLS token prepended so the
            # backbone always has a pooled position.
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len-1)):
                if i == 0:
                    segment = x[:, offset:offset + self.args.max_len-1]
                    segment_mask = mask[:, offset:offset + self.args.max_len-1]
                else:
                    cls_col = torch.full((x.shape[0], 1), self.cls_id,
                            dtype=x.dtype, device=x.device)
                    segment = torch.cat((cls_col,
                            x[:, offset:offset + self.args.max_len-1]), dim=1)
                    segment_mask = torch.cat((torch.ones_like(mask[:, :1]),
                            mask[:, offset:offset + self.args.max_len-1]), dim=1)
                logits = self.phenos(segment, segment_mask)[1]
                outs.append(logits)

            # Max-pool the window-level logits over windows.
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            # Non-overlapping windows; hidden states are concatenated so each
            # token is classified exactly once.
            for offset in range(0, x.shape[1], self.args.max_len):
                segment = x[:, offset:offset + self.args.max_len]
                segment_mask = mask[:, offset:offset + self.args.max_len]
                h = self.decisions(segment, segment_mask)[0].last_hidden_state
                outs.append(h)
            h = torch.cat(outs, 1)
            return self.classifier(h)
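
def _demo_mymodel_generate():
    """A minimal usage sketch, not part of the original module: runs
    MyModel.generate on an input longer than max_len so the sliding-window
    path executes. 'roberta-base' is an assumed backbone choice; it matches
    self.cls_id == 0, since RoBERTa's <s> token has id 0."""
    from types import SimpleNamespace
    demo_args = SimpleNamespace(num_labels=4, distil_att=False,
            task='seq', max_len=512)
    demo = MyModel(demo_args, AutoModel.from_pretrained('roberta-base'))
    ids = torch.randint(5, 1000, (2, 700))  # batch of 2, 700 token ids each
    mask = torch.ones_like(ids)
    print(demo.generate(ids, mask).shape)  # torch.Size([2, 4])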

class CNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.Sequential(
                nn.Conv1d(args.emb_size, args.hidden_size, args.kernels[0],
                    padding='same' if args.task == 'token' else 'valid'),
                nn.ReLU(),
                nn.MaxPool1d(1),  # window of 1: identity op
                nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[1],
                    padding='same' if args.task == 'token' else 'valid'),
                nn.ReLU(),
                nn.MaxPool1d(1),
                nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[2],
                    padding='same' if args.task == 'token' else 'valid'),
                nn.ReLU(),
                nn.MaxPool1d(1),
                )
        if args.task == 'seq':
            # Each 'valid' convolution shortens the sequence by kernel - 1,
            # so the flattened length is max_len - sum(kernels) + 3; generate
            # pads every window to args.max_len to match.
            out_shape = args.max_len - args.kernels[0] - args.kernels[1] - args.kernels[2] + 3
        elif args.task == 'token':
            out_shape = 1
        self.classifier = nn.Linear(args.hidden_size*out_shape, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None  # set externally by load_model

    def forward(self, x, _):
        x = x.to(self.device)
        bs = x.shape[0]
        x = self.emb(x)
        x = x.transpose(1, 2)  # (B, T, E) -> (B, E, T) for Conv1d
        x = self.model(x)
        x = self.dropout(x)
        if self.args.task == 'token':
            x = x.transpose(1, 2)  # back to (B, T, H)
            h = self.classifier(x)
            return x, h  # hidden states, per-token logits
        elif self.args.task == 'seq':
            x = x.reshape(bs, -1)
            x = self.classifier(x)
            return x  # sequence-level logits

    def generate(self, x, _):
        outs = []
        for offset in range(0, x.shape[1], self.args.max_len):
            segment = x[:, offset:offset + self.args.max_len]
            n = segment.shape[1]
            if n != self.args.max_len:
                # Right-pad the last window so the 'valid' conv shapes match.
                segment = torch.nn.functional.pad(segment, (0, self.args.max_len - n))
            if self.args.task == 'seq':
                logits = self(segment, None)
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0]
                h = h[:, :n]  # drop padded positions
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)
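
def _demo_cnn_shapes():
    """A hedged sketch, not original code: checks the 'valid'-padding
    arithmetic above. Each conv shortens the sequence by kernel - 1, so with
    kernels (3, 4, 5) and max_len 128 the flattened length is 128 - 12 + 3.
    All argument values here are illustrative assumptions."""
    from types import SimpleNamespace
    demo_args = SimpleNamespace(vocab_size=1000, emb_size=32, hidden_size=16,
            kernels=(3, 4, 5), task='seq', num_labels=4, max_len=128)
    net = CNN(demo_args)
    net.device = 'cpu'
    logits = net(torch.randint(0, 1000, (2, 128)), None)
    assert logits.shape == (2, 4)  # sequence-level logits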

class LSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.LSTM(args.emb_size, args.hidden_size, num_layers=args.num_layers,
                batch_first=True, bidirectional=True)
        # 'seq' pools the final hidden state of every layer and direction;
        # 'token' classifies the per-step outputs (2*hidden_size each).
        dim = 2*args.num_layers*args.hidden_size if args.task == 'seq' else 2*args.hidden_size
        self.classifier = nn.Linear(dim, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None  # set externally by load_model

    def forward(self, x, _):
        x = x.to(self.device)
        x = self.emb(x)
        o, (x, _) = self.model(x)  # o: (B, T, 2H); x: (2*num_layers, B, H)
        o_out = self.classifier(o) if self.args.task == 'token' else None
        if self.args.task == 'seq':
            x = torch.cat([h for h in x], 1)  # (B, 2*num_layers*H)
            x = self.dropout(x)
            x = self.classifier(x)
        return (x, o), o_out

    def generate(self, x, _):
        outs = []
        for offset in range(0, x.shape[1], self.args.max_len):
            segment = x[:, offset:offset + self.args.max_len]
            if self.args.task == 'seq':
                logits = self(segment, None)[0][0]
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0][1]  # per-step LSTM outputs
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)
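
def _demo_lstm_outputs():
    """A hedged sketch, not original code: shows what LSTM.forward returns.
    For task='seq' the final hidden states of every layer and direction are
    concatenated, matching the 2*num_layers*hidden_size classifier above.
    All argument values here are illustrative assumptions."""
    from types import SimpleNamespace
    demo_args = SimpleNamespace(vocab_size=1000, emb_size=32, hidden_size=16,
            num_layers=2, task='seq', num_labels=4, max_len=64)
    net = LSTM(demo_args)
    net.device = 'cpu'
    (logits, per_step), _ = net(torch.randint(0, 1000, (2, 64)), None)
    assert logits.shape == (2, 4)          # sequence-level logits
    assert per_step.shape == (2, 64, 32)   # per-step outputs, 2*hidden_size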

def load_model(args, device):
    if args.model == 'lstm':
        model = LSTM(args).to(device)
        model.device = device
    elif args.model == 'cnn':
        model = CNN(args).to(device)
        model.device = device
    else:
        model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device)
    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location=device), strict=False)
    if args.distil:
        # Build a frozen token-level teacher for distillation.
        args2 = copy.deepcopy(args)
        args2.task = 'token'
        # args2.num_labels = args.num_decs
        args2.num_labels = args.num_umls_tags
        model_B = MyModel(args2, AutoModel.from_pretrained(args.model_name)).to(device)
        model_B.load_state_dict(torch.load(args.distil_ckpt, map_location=device), strict=False)
        for p in model_B.parameters():
            p.requires_grad = False
    else:
        model_B = None
    if args.label_encoding == 'multiclass':
        if args.use_crf:
            # torchcrf is an optional dependency, imported lazily so the
            # module loads without it when the CRF head is unused.
            from torchcrf import CRF
            crit = CRF(args.num_labels, batch_first=True).to(device)
        else:
            crit = nn.CrossEntropyLoss(reduction='none')
    else:
        # Multi-label: per-element losses so the caller can mask padding.
        crit = nn.BCEWithLogitsLoss(
                pos_weight=torch.ones(args.num_labels).to(device)*args.pos_weight,
                reduction='none'
                )
    optimizer = AdamW(model.parameters(), lr=args.lr)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
            int(0.1*args.total_steps), args.total_steps)  # 10% linear warmup

    return model, crit, optimizer, lr_scheduler, model_B
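
def _demo_training_step(args, batch, device):
    """A hedged sketch, not from the original repo, of how load_model's
    pieces fit together for a token-level multi-label batch. Both criteria
    use reduction='none', so the caller masks padding before averaging;
    the (ids, mask, labels) batch layout is an assumption."""
    model, crit, optimizer, lr_scheduler, _ = load_model(args, device)
    ids, mask, labels = batch              # (B, T), (B, T), (B, T, num_labels)
    _, logits = model(ids, mask)
    loss = crit(logits, labels.float().to(logits.device))
    loss = (loss * mask.unsqueeze(-1).to(loss.device)).sum() / mask.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    return loss.item()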