import json
import os
import pickle

import huggingface_hub
import torch
from datasets import Dataset, load_from_disk
from rich import print
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer

from evaluate_model import compute_metrics
from utils import tokenize_and_align_labels

# Authenticate with the Hugging Face Hub. HF_TOKEN must be set in the
# environment (exported in the shell, or loaded from a local .env file
# with python-dotenv before running this script).
hf_token = os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)

# Fine-tuned token-classification checkpoint and its matching tokenizer.
checkpoint = 'elshehawy/finer-ord-transformers'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

data_path = './data/merged_dataset/'

# Evaluate on a small subset: the first 16 test examples.
test = load_from_disk(data_path)['test']
test = test.select(range(16))

feature_path = './data/ner_feature.pickle'

# Pickled `datasets` feature describing the NER tag set.
with open(feature_path, 'rb') as f:
    ner_feature = pickle.load(f)
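
# `ner_feature` is assumed to be the dataset's Sequence(ClassLabel) feature
# for `ner_tags`; its ClassLabel maps tag ids to strings and back, e.g.:
#
#     ner_feature.feature.int2str(0)        # -> a tag string such as 'O'
#     ner_feature.feature.str2int('B-ORG')  # -> the id used in `ner_tags`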

    
ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)


def collate_fn(data):
    """Gather features into parallel lists for one batch.

    Assumes tokenization already padded every example to the same length,
    so the lists can be turned into tensors downstream.
    """
    input_ids = [element['input_ids'] for element in data]
    token_type_ids = [element['token_type_ids'] for element in data]
    attention_mask = [element['attention_mask'] for element in data]
    labels = [element['labels'] for element in data]

    return input_ids, token_type_ids, attention_mask, labels
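
# NOTE: a library alternative to the hand-rolled collate_fn above would be
# DataCollatorForTokenClassification, which pads input_ids, attention_mask,
# and labels per batch and returns a dict of tensors. A minimal sketch:
#
#     from transformers import DataCollatorForTokenClassification
#     data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
#     loader = torch.utils.data.DataLoader(
#         tokenized_data, batch_size=16, collate_fn=data_collator
#     )
#     for batch in loader:
#         out = ner_model(**batch)  # batch includes padded 'labels'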


# Inference only: put the model in eval mode to disable dropout.
ner_model.eval()


def get_metrics_trf(data):
    """Tokenize `data`, run the NER model over it, and return the metrics."""
    data = Dataset.from_dict(data)

    # Mapping with batch_size=None processes the whole split in one call;
    # the hand-rolled collate_fn assumes this yields equal-length (padded)
    # sequences. Raw columns beyond the first two are dropped.
    tokenized_data = data.map(
        tokenize_and_align_labels,
        batched=True,
        batch_size=None,
        remove_columns=data.column_names[2:],
        fn_kwargs={'tokenizer': tokenizer},
    )

    loader = torch.utils.data.DataLoader(tokenized_data, batch_size=16, collate_fn=collate_fn)
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    ner_model.to(device)  # move the model once, not on every batch

    y_true, logits = [], []
    for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
        with torch.no_grad():
            logits.extend(
                ner_model(
                    input_ids=torch.tensor(input_ids).to(device),
                    token_type_ids=torch.tensor(token_type_ids).to(device),
                    attention_mask=torch.tensor(attention_mask).to(device),
                ).logits.cpu().numpy()
            )
        y_true.extend(labels)

    # compute_metrics is expected to argmax the logits and skip the ignored
    # (-100) positions when comparing against the gold labels.
    all_metrics = compute_metrics((logits, y_true))
    return all_metrics


def find_orgs_in_data(tokens, labels):
    """Group B-ORG/I-ORG tagged tokens into organization-name strings."""
    orgs = []
    prev_tok_id = -2  # index of the last token added to the current entity
    for i, (token, label) in enumerate(zip(tokens, labels)):
        if label == 'B-ORG':
            # A new entity starts; open a fresh span.
            orgs.append([token])
            prev_tok_id = i
        elif label == 'I-ORG' and (i - 1) == prev_tok_id:
            # Continuation of the current entity; must be contiguous with it.
            orgs[-1].append(token)
            prev_tok_id = i

    return [tokenizer.convert_tokens_to_string(org) for org in orgs]
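
# A worked example of the grouping above (hypothetical tokens and tags):
#
#     tokens = ['Shares', 'of', 'Goldman', 'Sachs', 'rose', '.']
#     labels = ['O', 'O', 'B-ORG', 'I-ORG', 'O', 'O']
#     find_orgs_in_data(tokens, labels)  # -> ['Goldman Sachs']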



def store_sample_data(data):
    """Extract id, raw text, and gold ORG spans for each sentence in `data`."""
    data = Dataset.from_dict(data)
    test_data = []

    for sent in data:
        # Map label ids back to tag strings before span extraction.
        labels = [ner_feature.feature.int2str(l) for l in sent['ner_tags']]
        sent_orgs = find_orgs_in_data(sent['tokens'], labels)

        sent_text = tokenizer.convert_tokens_to_string(sent['tokens'])
        test_data.append({
            'id': sent['id'],
            'text': sent_text,
            'orgs': sent_orgs,
        })

    return test_data
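

if __name__ == '__main__':
    # Minimal driver (a sketch): compute metrics on the 16-example subset and
    # dump both results to the paths used in this repo. Assumes the
    # ./metrics/trf/ and ./data/ directories already exist.
    all_metrics = get_metrics_trf(test[:16])
    with open('./metrics/trf/metrics.json', 'w') as f:
        json.dump(all_metrics, f)

    sample_data = store_sample_data(test[:16])
    with open('./data/sample_data.json', 'w') as f:
        json.dump(sample_data, f)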