from sentence_transformers import util
from tqdm import tqdm

def calc_recall(true_pos, false_neg, eps=1e-8):
    """Return recall, ``TP / (TP + FN)``.

    ``eps`` is added to the denominator so that zero counts yield 0.0
    instead of raising ``ZeroDivisionError``.
    """
    denominator = true_pos + false_neg + eps
    return true_pos / denominator



def calc_precision(true_pos, false_pos, eps=1e-8):
    """Return precision, ``TP / (TP + FP)``.

    ``eps`` is added to the denominator so that zero counts yield 0.0
    instead of raising ``ZeroDivisionError``.
    """
    denominator = true_pos + false_pos + eps
    return true_pos / denominator



def calc_f1_score(precision, recall, eps=1e-8):
    """Return the F1 score, the harmonic mean of ``precision`` and ``recall``.

    ``eps`` is added to the denominator so that ``precision == recall == 0``
    yields 0.0 instead of raising ``ZeroDivisionError``.
    """
    numerator = 2 * precision * recall
    denominator = precision + recall + eps
    return numerator / denominator



def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8):
    """Compute embedding-based soft-match recall, precision and F1.

    ``true`` and ``predicted`` are walked in lockstep with ``zip``; each pair
    is the gold and predicted entity lists for one sample. Because of ``zip``,
    trailing items in the longer sequence are silently ignored. Within a
    sample, entities are encoded with ``model.encode`` and compared by cosine
    similarity; a (gold, predicted) pair matches when its similarity is
    ``>= threshold``.

    Counts, summed over all samples:

    * gold entity matched by at least one prediction -> true positive
    * gold entity matched by no prediction           -> false negative
    * prediction matching no gold entity             -> false positive

    Args:
        true: Sized sequence of per-sample gold entity lists. ``len`` is
            called on it to size the progress bar.
        predicted: Iterable of per-sample predicted entity lists.
        model: Object exposing ``encode(texts, convert_to_tensor=True)``,
            presumably a ``SentenceTransformer`` -- TODO confirm.
        threshold: Minimum cosine similarity for two entities to match.
        eps: Smoothing term added to every metric's denominator to avoid
            division by zero. It is applied to recall, precision and F1.

    Returns:
        dict with keys ``'recall'``, ``'precision'`` and ``'f1'``.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0

    for true_ents, pred_ents in tqdm(zip(true, predicted), total=len(true)):
        # Nothing to find: every prediction in this sample is spurious.
        if len(true_ents) == 0:
            false_pos += len(pred_ents)
            continue

        # Nothing predicted: every gold entity in this sample was missed.
        if len(pred_ents) == 0:
            false_neg += len(true_ents)
            continue

        embed_true = model.encode(true_ents, convert_to_tensor=True)
        embed_pred = model.encode(pred_ents, convert_to_tensor=True)

        # Rows index gold entities, columns index predictions.
        hits = util.pytorch_cos_sim(embed_true, embed_pred) >= threshold

        # A gold entity is found if any prediction matches it (reduce over columns).
        matched_true = hits.any(dim=1)
        # A prediction is correct if it matches any gold entity (reduce over rows).
        matched_pred = hits.any(dim=0)

        n_true_matched = int(matched_true.sum())
        true_pos += n_true_matched
        false_neg += matched_true.numel() - n_true_matched
        false_pos += matched_pred.numel() - int(matched_pred.sum())

    # NOTE(review): true_pos counts matched *gold* entities but is also used as
    # the precision numerator. When several gold entities match a single
    # prediction, this can overstate precision relative to a
    # matched-predictions / all-predictions definition. Kept as-is so that
    # reported numbers stay comparable; confirm this is intended.
    recall = calc_recall(true_pos, false_neg, eps=eps)
    precision = calc_precision(true_pos, false_pos, eps=eps)
    f1_score = calc_f1_score(precision, recall, eps=eps)

    return {
        'recall': recall,
        'precision': precision,
        'f1': f1_score,
    }