import torch
import time
import pandas as pd
import numpy as np
import pickle
import os

from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
import fuson_plm.benchmarking.puncta.config as config
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy, get_local_time, CustomParams

def check_splits(df):
    """Validate the benchmark train/test split table.

    Args:
        df: DataFrame with a ``split`` column (expected labels ``'train'`` /
            ``'test'``) and an ``aa_seq`` column of sequences.

    Raises:
        ValueError: if any row has no split, if the split labels are not
            exactly ``{'train', 'test'}`` (an unexpected label such as
            ``'val'`` or a missing partition), or if any ``aa_seq`` value
            appears more than once.
    """
    # make sure everything has a split
    if df['split'].isna().any():
        raise ValueError("Error: not every benchmarking sequence has been allocated to a split (train or test)")
    # make sure the only labels are train and test, and that both are present.
    # Comparing whole sets catches extra labels as well as a missing partition.
    found = set(df['split'].unique())
    if found != {'train', 'test'}:
        raise ValueError(f"Error: splits column should only have 'train' and 'test'. Found: {sorted(map(str, found))}")
    # make sure there are no duplicate sequences
    if df['aa_seq'].duplicated().any():
        raise ValueError("Error: duplicate sequences provided")
    
def _model_detail_columns(details):
    """Map the benchmark ``details`` dict onto the leading column names of the results CSVs."""
    return {
        'Model Type': details['model_type'],
        'Model Name': details['model'],
        'Model Epoch': details['epoch'],
    }

def _prepend_model_details(stats_df, model_cols):
    """Return *stats_df* with *model_cols* added as the first columns.

    Values are assigned as scalars, so they broadcast to every row whether the
    stats frame has one row or several.
    """
    original_cols = list(stats_df.columns)
    with_details = stats_df.assign(**model_cols)
    return with_details[list(model_cols) + original_cols]

def _append_csv(df, path):
    """Append *df* to the CSV at *path*, writing the header only if the file does not exist yet."""
    df.to_csv(path, mode='a', index=False, header=not os.path.exists(path))

def train_and_evaluate_puncta_predictor(details, splits_with_embeddings,outdir,task='nucleus',class1_thresh=0.5,n_estimators=50,tree_method="hist"):
    """Train one predictor for a puncta task and append its test-set stats to a CSV.

    Args:
        details: dict describing the embedding model; must contain the keys
            ``'model_type'``, ``'model'`` and ``'epoch'``.
        splits_with_embeddings: DataFrame with ``split`` ('train'/'test'),
            ``embedding`` and a label column named after ``task``.
        outdir: directory the results CSV is written/appended to.
        task: label column to predict: 'nucleus', 'cytoplasm', or 'formation'.
        class1_thresh: threshold passed to ``evaluate_predictor``. It also
            appears in the output filename for the 'formation' task.
        n_estimators, tree_method: forwarded to ``train_final_predictor``.

    For 'formation' the custom-threshold stats are saved to
    ``{outdir}/formation_verificationFOs_{class1_thresh}thresh_results.csv``.
    For any other task the automatic stats are saved to
    ``{outdir}/{task}_verificationFOs_results.csv``. Existing files are
    appended to without repeating the header.
    """
    # Unpack the details dictionary up front so a malformed dict fails before training.
    model_cols = _model_detail_columns(details)

    # prepare train and test sets for model
    is_train = splits_with_embeddings['split'] == 'train'
    is_test = splits_with_embeddings['split'] == 'test'
    train_split = splits_with_embeddings.loc[is_train].reset_index(drop=True)
    test_split = splits_with_embeddings.loc[is_test].reset_index(drop=True)

    X_train = np.array(train_split['embedding'].tolist())
    y_train = np.array(train_split[task].tolist())
    X_test = np.array(test_split['embedding'].tolist())
    y_test = np.array(test_split[task].tolist())

    # Train the final model on all the training data
    clf = train_final_predictor(X_train, y_train, n_estimators=n_estimators, tree_method=tree_method)

    # Evaluate it on the held-out verification set
    automatic_stats_df, custom_stats_df = evaluate_predictor(clf, X_test, y_test, class1_thresh=class1_thresh)

    # Formation is reported at the custom threshold; nucleus/cytoplasm use the automatic stats.
    if task == "formation":
        stats_df = _prepend_model_details(custom_stats_df, model_cols)
        stats_path = f'{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv'
    else:
        stats_df = _prepend_model_details(automatic_stats_df, model_cols)
        stats_path = f'{outdir}/{task}_verificationFOs_results.csv'
    _append_csv(stats_df, stats_path)
    
def main():
    """Run the puncta benchmark end to end.

    Creates a timestamped run directory under ``results/`` and logs to it.
    Embeds the sequences in ``splits.csv`` with the models enabled in
    ``config``. Validates the splits, then trains and evaluates one predictor
    per (embedding set, task) pair. Finally, renders the summary bar charts.
    """
    # make output directory for this run (makedirs also creates results/)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f'{output_dir}/puncta_benchmark_log.txt'):
        # print configurations
        print_configpy(config)

        # Must take effect before any CUDA context is created. torch initialises
        # CUDA lazily, so setting it here (after `import torch`) is presumably
        # still effective. Verify if device selection misbehaves.
        os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
        log_update("\nChecking on environment variables...")
        log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

        # make embeddings if needed
        all_embedding_paths = embed_dataset_for_benchmark(
                                            fuson_ckpts=config.FUSONPLM_CKPTS,
                                            input_data_path='splits.csv', input_fname='FOdb_puncta_sequences',
                                            average=True, seq_col='aa_seq',
                                            benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
                                            benchmark_esm=config.BENCHMARK_ESM,
                                            benchmark_fo_puncta_ml=config.BENCHMARK_FO_PUNCTA_ML,
                                            benchmark_prott5 = config.BENCHMARK_PROTT5,
                                            overwrite=config.PERMISSION_TO_OVERWRITE)

        # load the splits with labels and sanity-check them
        splits = pd.read_csv('splits.csv')
        check_splits(splits)
        n_train = int((splits['split'] == 'train').sum())
        n_test = int((splits['split'] == 'test').sum())
        log_update(f"\nSplit breakdown...\n\t{n_train} Training FOs\n\t{n_test} Verification FOs")

        # set training constants
        train_params = CustomParams(
            N_ESTIMATORS = 50,
            TREE_METHOD = "hist",
            CLASS1_THRESHOLDS = {
                'nucleus': 0.83,
                'cytoplasm': 0.83,
                'formation': 0.83
            },
        )
        log_update("\nTraining configs:")
        train_params.print_config(indent='\t')

        log_update("\nTraining models")
        # loop through the embedding paths and train each one
        for embedding_path, details in all_embedding_paths.items():
            log_update(f"\tBenchmarking embeddings at: {embedding_path}")
            # NOTE(review): pickle.load is only safe on trusted files. These paths
            # presumably come from embed_dataset_for_benchmark above.
            try:
                with open(embedding_path, "rb") as f:
                    embeddings = pickle.load(f)
            except Exception as e:
                # Narrower than the old bare `except:` (which also caught
                # KeyboardInterrupt); chaining keeps the real cause visible.
                raise Exception(f"Cannot read embeddings from {embedding_path}") from e

            # combine the embeddings (keyed by sequence) and splits into one dataframe
            splits_with_embeddings = pd.DataFrame(list(embeddings.items()), columns=['aa_seq', 'embedding'])
            # The left merge silently drops split sequences that have no embedding, so report them.
            n_missing = len(set(splits['aa_seq']) - set(embeddings))
            if n_missing:
                log_update(f"\t\tWARNING: {n_missing} sequences in splits.csv have no embedding and will be skipped")
            splits_with_embeddings = pd.merge(splits_with_embeddings, splits, on='aa_seq', how='left')

            for task in ['nucleus', 'cytoplasm', 'formation']:
                log_update(f"\t\tTask: {task}")
                train_and_evaluate_puncta_predictor(details, splits_with_embeddings, output_dir, task=task,
                                                    class1_thresh=train_params.CLASS1_THRESHOLDS[task],
                                                    n_estimators=train_params.N_ESTIMATORS, tree_method=train_params.TREE_METHOD)

        log_update("\nMaking summary figures:\n")
        log_update("\tbar charts...")
        os.makedirs(f"{output_dir}/figures", exist_ok=True)
        make_all_final_bar_charts(output_dir)
        log_update("\tDone.")
            
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, never on import.
    main()