import gc

import numpy as np
import pandas as pd

import torch
from torch import nn
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig


config = dict(
    # basic
    seed=3407,
    num_jobs=1,
    num_labels=2,

    # model info
    tokenizer_path='roberta-large',   # 'allenai/biomed_roberta_base'
    model_checkpoint='roberta-large', # 'allenai/biomed_roberta_base'
    device='cuda' if torch.cuda.is_available() else 'cpu',

    # training parameters
    max_length=512,
    batch_size=16,

    # for this notebook
    debug=False,
)


def create_sample_test():
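    """Load the competition CSVs, merge the test rows with patient notes and
    feature texts, normalize the feature text, and return one random sample row."""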
    feats = pd.read_csv(f"../input/nbme-score-clinical-patient-notes/features.csv")
    feats.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"
    
    notes = pd.read_csv(f"../input/nbme-score-clinical-patient-notes/patient_notes.csv")
    test = pd.read_csv(f"../input/nbme-score-clinical-patient-notes/test.csv")

    merged = test.merge(notes, how = "left")
    merged = merged.merge(feats, how = "left")

    def process_feature_text(text):
        return text.replace("-OR-", ";-").replace("-", " ")
    merged["feature_text"] = [process_feature_text(x) for x in merged["feature_text"]]
    
    return merged.sample(1).reset_index(drop=True)

class NBMETestData(torch.utils.data.Dataset):
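    """Inference dataset: tokenizes each (feature_text, pn_history) pair with the
    feature text as the first segment and the note as the second, keeping the
    offset mapping and sequence ids needed to map token-level predictions back
    to character spans (requires a fast tokenizer)."""
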
    def __init__(self, feature_text, pn_history, tokenizer):
        self.feature_text = feature_text
        self.pn_history = pn_history
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.feature_text)
    
    def __getitem__(self, idx):
        tokenized = self.tokenizer(
            self.feature_text[idx],
            self.pn_history[idx],
            truncation = "only_second",
            max_length = config['max_length'],
            padding = "max_length",
            return_offsets_mapping = True
        )
        tokenized["sequence_ids"] = tokenized.sequence_ids()

        input_ids = np.array(tokenized["input_ids"])
        attention_mask = np.array(tokenized["attention_mask"])
        offset_mapping = np.array(tokenized["offset_mapping"])
        sequence_ids = np.array(tokenized["sequence_ids"]).astype("float16")

        return {
            'input_ids': input_ids, 
            'attention_mask': attention_mask, 
            'offset_mapping': offset_mapping, 
            'sequence_ids': sequence_ids,
        }


class NBMEModel(nn.Module):
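    """Token-level span model: a pretrained transformer encoder (roberta-large,
    hidden size 1024) with multi-sample dropout, i.e. five dropout rates from
    0.1 to 0.5 sharing a single linear head, with the resulting logits averaged."""
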
    def __init__(self, num_labels=2, path=None):
        super().__init__()
        
        self.path = path
        self.num_labels = num_labels
        self.transformer = transformers.AutoModel.from_pretrained(config['model_checkpoint'])
        self.dropout = nn.Dropout(0.1)
        
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.2)
        self.dropout3 = nn.Dropout(0.3)
        self.dropout4 = nn.Dropout(0.4)
        self.dropout5 = nn.Dropout(0.5)
        
        self.output = nn.Linear(1024, 1)
        
        if self.path is not None:
            self.load_state_dict(torch.load(self.path)['model'])
    
    def forward(self, data):
        
        ids = data['input_ids']
        mask = data['attention_mask']
        target = data.get('targets')

        transformer_out = self.transformer(ids, mask)
        sequence_output = transformer_out[0]
        sequence_output = self.dropout(sequence_output)
    
        logits1 = self.output(self.dropout1(sequence_output))
        logits2 = self.output(self.dropout2(sequence_output))
        logits3 = self.output(self.dropout3(sequence_output))
        logits4 = self.output(self.dropout4(sequence_output))
        logits5 = self.output(self.dropout5(sequence_output))

        logits = (logits1 + logits2 + logits3 + logits4 + logits5) / 5
        ret = {
            'logits': torch.sigmoid(logits), 
        }
        
        loss = 0

        if target is not None:
            loss1 = self.get_loss(logits1, target)
            loss2 = self.get_loss(logits2, target)
            loss3 = self.get_loss(logits3, target)
            loss4 = self.get_loss(logits4, target)
            loss5 = self.get_loss(logits5, target)
            loss = (loss1 + loss2 + loss3 + loss4 + loss5) / 5
            ret['loss'] = loss
            ret['target'] = target

        return ret

        
    def get_optimizer(self, learning_rate, weight_decay):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        if self.path is not None:
            optimizer.load_state_dict(torch.load(self.path)['optimizer'])
        
        return optimizer
            
    def get_scheduler(self, optimizer, num_warmup_steps, num_training_steps):
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
        if self.path is not None:
            scheduler.load_state_dict(torch.load(self.path)['scheduler'])
            
        return scheduler
    
    def get_loss(self, output, target):
        loss_fn = nn.BCEWithLogitsLoss(reduction="none")
        loss = loss_fn(output.view(-1, 1), target.view(-1, 1))
        loss = torch.masked_select(loss, target.view(-1, 1) != -100).mean()
        return loss


def get_location_predictions(preds, offset_mapping, sequence_ids, test=False):
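    """Convert per-token probabilities into character-level spans: patient-note
    tokens with probability > 0.5 are merged into contiguous (start, end) offsets;
    with test=True the spans are returned as semicolon-separated "start end" strings."""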
    all_predictions = []
    for pred, offsets, seq_ids in zip(preds, offset_mapping, sequence_ids):
        start_idx = None
        current_preds = []
        for p, o, s_id in zip(pred, offsets, seq_ids):
            # skip feature-text tokens (sequence id 0) and special/padding tokens
            # (sequence id None, or NaN after the float cast in NBMETestData)
            if s_id is None or s_id == 0 or np.isnan(s_id):
                continue
            if p > 0.5:
                if start_idx is None:
                    start_idx = o[0]
                end_idx = o[1]
            elif start_idx is not None:
                if test:
                    current_preds.append(f"{start_idx} {end_idx}")
                else:
                    current_preds.append((start_idx, end_idx))
                start_idx = None
        # close a span that is still open at the end of the sequence
        if start_idx is not None:
            if test:
                current_preds.append(f"{start_idx} {end_idx}")
            else:
                current_preds.append((start_idx, end_idx))
        if test:
            all_predictions.append("; ".join(current_preds))
        else:
            all_predictions.append(current_preds)
    return all_predictions
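
# Quick sanity check for get_location_predictions on toy values (not competition
# data): the 2nd and 3rd tokens exceed the 0.5 threshold and are merged into one
# character span; the first token (sequence id None) is ignored.
_toy_preds = [[0.1, 0.9, 0.8, 0.2]]
_toy_offsets = [[(0, 0), (0, 5), (6, 10), (11, 15)]]
_toy_seq_ids = [[None, 1, 1, 1]]
assert get_location_predictions(_toy_preds, _toy_offsets, _toy_seq_ids) == [[(0, 10)]]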



def predict_location_preds(tokenizer, model, feature_text, pn_history, pn_history_lower):
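    """Run the model over one (feature_text, note) pair, collect per-token
    probabilities, and return the predicted character spans together with the
    matching text pulled from the original (non-lowercased) note."""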

    test_ds = NBMETestData(feature_text, pn_history_lower, tokenizer)
    test_dl = torch.utils.data.DataLoader(
        test_ds, 
        batch_size=config['batch_size'], 
        pin_memory=True, 
        shuffle=False, 
        drop_last=False
    )

    offsets = []
    seq_ids = []
    preds = []

    with torch.no_grad():
        for batch in test_dl:

            for k, v in batch.items():
                if k not in ['offset_mapping', 'sequence_ids']:
                    batch[k] = v.to(config['device'])

            logits = model(batch)['logits']
            preds.append(logits.cpu().numpy())

            offset_mapping = batch['offset_mapping']
            sequence_ids = batch['sequence_ids']
            offsets.append(offset_mapping.cpu().numpy())
            seq_ids.append(sequence_ids.cpu().numpy())

    all_preds = np.concatenate(preds, axis=0).astype(np.float32)
    torch.cuda.empty_cache()

    all_preds = all_preds.squeeze()

    offsets = np.concatenate(offsets, axis=0)
    seq_ids = np.concatenate(seq_ids, axis=0)

    # print(all_preds.shape, offsets.shape, seq_ids.shape)

    location_preds = get_location_predictions([all_preds], offsets, seq_ids, test=False)[0]
    
    x = []
    
    for location in location_preds:
        x.append(pn_history[0][location[0]: location[1]])
    
    return location_preds, ', '.join(x)

def get_predictions(feature_text, pn_history):
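    """Interactive entry point: normalizes the feature text and note, runs the
    model, and returns a human-readable message describing whether and where the
    feature appears in the note."""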
    # expand "-OR-" and hyphens before lowercasing; lowercasing first would turn
    # "-OR-" into "-or-" and the replace would miss it
    feature_text = feature_text.replace("-OR-", ";-").replace("-", " ").lower()
    pn_history_lower = pn_history.lower()
    
    location_preds, pred_string = predict_location_preds(tokenizer, model, [feature_text], [pn_history], [pn_history_lower])
    
    if pred_string == "":
        pred_string = 'Feature not present!'
    else:
        pred_string = 'Feature is present!' + '\nText Span - ' + pred_string
    
    return pred_string

tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_path'])
path = 'model_large_pseudo_label.pth'

model = NBMEModel().to(config['device'])
model.load_state_dict(
    torch.load(
        path, 
        map_location=torch.device(config['device'])
    )
)
model.eval()
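
# Example call (illustrative only: the note text below is made up, and the
# checkpoint loaded above must be available for this to work):
# print(get_predictions(
#     "Last Pap smear 1 year ago",
#     "17 yo F, LMP 2 weeks ago. Last Pap smear was one year ago and was normal.",
# ))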