import os
import pandas as pd

from huggingface_hub import HfApi

script_dir = os.path.dirname(os.path.abspath(__file__))  # Directory of the running script



def get_baseline_df(selected_methods, selected_metrics, leaderboard_path="/tmp/leaderboard_results.csv"):
    if not os.path.exists(leaderboard_path):
        # An empty list means only the leaderboard file is downloaded
        # (download_from_hub always adds "leaderboard" to the requested benchmarks).
        benchmark_types = []
        download_from_hub(benchmark_types)

    leaderboard_df = pd.read_csv(leaderboard_path)

    if selected_methods is not None and selected_metrics is not None:
        present_columns = ["Method"] + selected_metrics
        leaderboard_df = leaderboard_df[leaderboard_df['Method'].isin(selected_methods)][present_columns]
    return leaderboard_df


def save_results(method_name, benchmark_types, results, repo_id="HUBioDataLab/probe-data", repo_type="space", temporary=False):
    # First, download the current result files to be updated from the Hub repo.
    # Note: download_from_hub appends "leaderboard" to benchmark_types in place,
    # so the updated leaderboard file is also covered by the upload below.
    download_from_hub(benchmark_types, repo_id, repo_type)

    #Update local files
    for benchmark_type in benchmark_types:
        if benchmark_type == 'similarity':
            save_similarity_output(results['similarity'], method_name)
        elif benchmark_type == 'function':
            save_function_output(results['function'], method_name)
        elif benchmark_type == 'family':
            save_family_output(results['family'], method_name)
        elif benchmark_type == "affinity":
            save_affinity_output(results['affinity'], method_name)

    if not temporary:
        # Upload the updated local files (benchmarks and leaderboard) back to the Hub repo
        upload_to_hub(benchmark_types, repo_id, repo_type)

    return 0


def download_from_hub(benchmark_types, repo_id="HUBioDataLab/probe-data", repo_type="space"):
    api = HfApi(token=os.getenv("api-key"))  # HF token stored in the "api-key" secret

    # The leaderboard file is always refreshed alongside the requested benchmarks.
    # Note: this appends to the caller's list in place, so callers (e.g. save_results)
    # will also see "leaderboard" in benchmark_types afterwards.
    benchmark_types.append("leaderboard")
    for benchmark in benchmark_types:
        file_name = f"{benchmark}_results.csv"
        local_path = f"/tmp/{file_name}"
        
        try:
            # Download the file from the specified repo
            api.hf_hub_download(
                repo_id=repo_id,
                repo_type=repo_type,
                filename=file_name,
                local_dir="/tmp",
                token=os.getenv("api-key"),
            )
            print(f"Downloaded {file_name} from {repo_id} to {local_path}")

        except Exception as e:
            print(f"Failed to download {file_name}: {e}")


    return 0


def upload_to_hub(benchmark_types, repo_id="HUBioDataLab/probe-data", repo_type="space"):
    api = HfApi(token=os.getenv("api-key"))  # same "api-key" secret as in download_from_hub

    for benchmark in benchmark_types:
        file_name = f"{benchmark}_results.csv"
        local_path = f"/tmp/{file_name}"

        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=file_name,
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message=f"Updating {file_name}"
        )
        print(f"Uploaded {local_path} to {repo_id}/{file_name}")

        os.remove(local_path)
        print(f"Removed local file: {local_path}")

    return 0


def save_similarity_output(
    output_dict,
    method_name,
    leaderboard_path="/tmp/leaderboard_results.csv",
    similarity_path="/tmp/similarity_results.csv",
):
    # Load or initialize the DataFrames
    if os.path.exists(leaderboard_path):
        leaderboard_df = pd.read_csv(leaderboard_path)
    else:
        print("Leaderboard file not found!")
        return -1

    if os.path.exists(similarity_path):
        similarity_df = pd.read_csv(similarity_path)
    else:
        print("Similarity file not found!")
        return -1

    if method_name not in similarity_df['Method'].values:
        # Create a new row for the method with default values
        new_row = {col: None for col in similarity_df.columns}
        new_row['Method'] = method_name
        similarity_df = pd.concat([similarity_df, pd.DataFrame([new_row])], ignore_index=True)

    if method_name not in leaderboard_df['Method'].values:
        new_row = {col: None for col in leaderboard_df.columns}
        new_row['Method'] = method_name
        leaderboard_df = pd.concat([leaderboard_df, pd.DataFrame([new_row])], ignore_index=True)

    averages = {}
    for dataset in ['sparse', '200', '500']:
        correlation_values = []
        pvalue_values = []

        for aspect in ['MF', 'BP', 'CC']:
            correlation_key = f"{dataset}_{aspect}_correlation"
            pvalue_key = f"{dataset}_{aspect}_pvalue"

            # Update correlation if present
            if correlation_key in output_dict:
                correlation = output_dict[correlation_key].item()
                correlation_values.append(correlation)
                similarity_df.loc[similarity_df['Method'] == method_name, correlation_key] = correlation
                leaderboard_df.loc[leaderboard_df['Method'] == method_name, f"sim_{correlation_key}"] = correlation

            # Update p-value if present
            if pvalue_key in output_dict:
                pvalue = output_dict[pvalue_key].item()
                pvalue_values.append(pvalue)
                similarity_df.loc[similarity_df['Method'] == method_name, pvalue_key] = pvalue
                leaderboard_df.loc[leaderboard_df['Method'] == method_name, f"sim_{pvalue_key}"] = pvalue

        # Calculate averages if all three aspects are present
        if len(correlation_values) == 3:
            averages[f"{dataset}_Ave_correlation"] = sum(correlation_values) / 3
            similarity_df.loc[similarity_df['Method'] == method_name, f"{dataset}_Ave_correlation"] = averages[f"{dataset}_Ave_correlation"]
            leaderboard_df.loc[leaderboard_df['Method'] == method_name, f"sim_{dataset}_Ave_correlation"] = averages[f"{dataset}_Ave_correlation"]

        if len(pvalue_values) == 3:
            averages[f"{dataset}_Ave_pvalue"] = sum(pvalue_values) / 3
            similarity_df.loc[similarity_df['Method'] == method_name, f"{dataset}_Ave_pvalue"] = averages[f"{dataset}_Ave_pvalue"]
            leaderboard_df.loc[leaderboard_df['Method'] == method_name, f"sim_{dataset}_Ave_pvalue"] = averages[f"{dataset}_Ave_pvalue"]

    leaderboard_df.to_csv(leaderboard_path, index=False)
    similarity_df.to_csv(similarity_path, index=False)

    return 0


def save_function_output(
    model_output, 
    method_name, 
    func_results_path="/tmp/function_results.csv", 
    leaderboard_path="/tmp/leaderboard_results.csv"
):
    # Load or initialize the DataFrames
    if os.path.exists(leaderboard_path):
        leaderboard_df = pd.read_csv(leaderboard_path)
    else:
        print("Leaderboard file not found!")
        return -1

    if os.path.exists(func_results_path):
        func_results_df = pd.read_csv(func_results_path)
    else:
        print("Function file not found!")
        return -1

    if method_name not in func_results_df['Method'].values:
        # Create a new row for the method with default values
        new_row = {col: None for col in func_results_df.columns}
        new_row['Method'] = method_name
        func_results_df = pd.concat([func_results_df, pd.DataFrame([new_row])], ignore_index=True)

    if method_name not in leaderboard_df['Method'].values:
        new_row = {col: None for col in leaderboard_df.columns}
        new_row['Method'] = method_name
        leaderboard_df = pd.concat([leaderboard_df, pd.DataFrame([new_row])], ignore_index=True)

    
    # Storage for averaging in leaderboard results
    metrics_sum = {
        'accuracy': {'BP': [], 'CC': [], 'MF': []},
        'F1': {'BP': [], 'CC': [], 'MF': []},
        'precision': {'BP': [], 'CC': [], 'MF': []},
        'recall': {'BP': [], 'CC': [], 'MF': []}
    }

    # Iterate over each entry in model_output.
    # Each entry is a flat sequence whose first element is a "<aspect>_<dataset1>_<dataset2>"
    # key; positions 1, 4, 7 and 10 hold the accuracy, F1, precision and recall scores.
    for entry in model_output:
        key = entry[0]
        accuracy, f1, precision, recall = entry[1], entry[4], entry[7], entry[10]

        # Parse the key to extract the GO aspect and the dataset pair
        aspect, dataset1, dataset2 = key.split('_')

        # Save each metric to function_results under its respective column
        func_results_df.loc[func_results_df['Method'] == method_name, f"{aspect}_{dataset1}_{dataset2}_accuracy"] = accuracy
        func_results_df.loc[func_results_df['Method'] == method_name, f"{aspect}_{dataset1}_{dataset2}_F1"] = f1
        func_results_df.loc[func_results_df['Method'] == method_name, f"{aspect}_{dataset1}_{dataset2}_precision"] = precision
        func_results_df.loc[func_results_df['Method'] == method_name, f"{aspect}_{dataset1}_{dataset2}_recall"] = recall

        # Add values for leaderboard averaging
        metrics_sum['accuracy'][aspect].append(accuracy)
        metrics_sum['F1'][aspect].append(f1)
        metrics_sum['precision'][aspect].append(precision)
        metrics_sum['recall'][aspect].append(recall)

    # Calculate averages for each aspect and overall (if all aspects have entries)
    for metric in ['accuracy', 'F1', 'precision', 'recall']:
        for aspect in ['BP', 'CC', 'MF']:
            if metrics_sum[metric][aspect]:
                aspect_average = sum(metrics_sum[metric][aspect]) / len(metrics_sum[metric][aspect])
                leaderboard_df.loc[leaderboard_df['Method'] == method_name, f"func_{aspect}_{metric}"] = aspect_average

        # Calculate overall average if each aspect has entries
        if all(metrics_sum[metric][aspect] for aspect in ['BP', 'CC', 'MF']):
            overall_average = sum(
                sum(metrics_sum[metric][aspect]) / len(metrics_sum[metric][aspect])
                for aspect in ['BP', 'CC', 'MF']
            ) / 3
            leaderboard_df.loc[leaderboard_df['Method'] == method_name, f"func_Ave_{metric}"] = overall_average

    # Save updated DataFrames to CSV
    func_results_df.to_csv(func_results_path, index=False)
    leaderboard_df.to_csv(leaderboard_path, index=False)

    return 0

    
def save_family_output(
    model_output, 
    method_name, 
    leaderboard_path="/tmp/leaderboard_results.csv", 
    family_results_path="/tmp/family_results.csv"
):
    # Load or initialize the DataFrames
    if os.path.exists(leaderboard_path):
        leaderboard_df = pd.read_csv(leaderboard_path)
    else:
        print("Leaderboard file not found!")
        return -1

    if os.path.exists(family_results_path):
        family_results_df = pd.read_csv(family_results_path)
    else:
        print("Family file not found!")
        return -1

    if method_name not in family_results_df['Method'].values:
        # Create a new row for the method with default values
        new_row = {col: None for col in family_results_df.columns}
        new_row['Method'] = method_name
        family_results_df = pd.concat([family_results_df, pd.DataFrame([new_row])], ignore_index=True)

    if method_name not in leaderboard_df['Method'].values:
        new_row = {col: None for col in leaderboard_df.columns}
        new_row['Method'] = method_name
        leaderboard_df = pd.concat([leaderboard_df, pd.DataFrame([new_row])], ignore_index=True)

    # Iterate through the datasets and metrics
    for dataset, metrics in model_output.items():
        for metric, values in metrics.items():
            # Calculate the average for each metric in leaderboard results
            avg_value = sum(values) / len(values) if values else None
            leaderboard_df.loc[leaderboard_df['Method'] == method_name, f"fam_{dataset}_{metric}_ave"] = avg_value

            # Save each fold result for family results
            for i, value in enumerate(values):
                family_results_df.loc[family_results_df['Method'] == method_name, f"{dataset}_{metric}_{i}"] = value

    # Save updated DataFrames to CSV
    leaderboard_df.to_csv(leaderboard_path, index=False)
    family_results_df.to_csv(family_results_path, index=False)

    return 0


def save_affinity_output(
    model_output, 
    method_name, 
    leaderboard_path="/tmp/leaderboard_results.csv", 
    affinity_results_path="/tmp/affinity_results.csv"
):
    # Load or initialize the DataFrames
    if os.path.exists(leaderboard_path):
        leaderboard_df = pd.read_csv(leaderboard_path)
    else:
        print("Leaderboard file not found!")
        return -1

    if os.path.exists(affinity_results_path):
        affinity_results_df = pd.read_csv(affinity_results_path)
    else:
        print("Affinity file not found!")
        return -1

    if method_name not in affinity_results_df['Method'].values:
        # Create a new row for the method with default values
        new_row = {col: None for col in affinity_results_df.columns}
        new_row['Method'] = method_name
        affinity_results_df = pd.concat([affinity_results_df, pd.DataFrame([new_row])], ignore_index=True)

    if method_name not in leaderboard_df['Method'].values:
        new_row = {col: None for col in leaderboard_df.columns}
        new_row['Method'] = method_name
        leaderboard_df = pd.concat([leaderboard_df, pd.DataFrame([new_row])], ignore_index=True)

    # Process 'summary' section for leaderboard results
    summary = model_output.get('summary', {})
    if summary:
        leaderboard_df.loc[leaderboard_df['Method'] == method_name, 'aff_mse_ave'] = summary.get('val_mse_error')
        leaderboard_df.loc[leaderboard_df['Method'] == method_name, 'aff_mae_ave'] = summary.get('val_mae_error')
        leaderboard_df.loc[leaderboard_df['Method'] == method_name, 'aff_corr_ave'] = summary.get('validation_corr')

    # Process 'detail' section for affinity results
    detail = model_output.get('detail', {})
    if detail:
        # Save each cross-validation fold result (10 folds expected) for mse, mae, and corr,
        # iterating over however many folds are actually present to avoid index errors
        for metric_key, column_prefix in [('val_mse_errors', 'mse'),
                                          ('val_mae_errors', 'mae'),
                                          ('validation_corrs', 'corr')]:
            for i, value in enumerate(detail.get(metric_key, [])):
                affinity_results_df.loc[affinity_results_df['Method'] == method_name, f"{column_prefix}_{i}"] = value

    # Save updated DataFrames to CSV
    leaderboard_df.to_csv(leaderboard_path, index=False)
    affinity_results_df.to_csv(affinity_results_path, index=False)

    return 0
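

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the production flow).
# The method name, dict layout, and metric values below are hypothetical; the
# real `results` dict comes from the PROBE benchmark runs. numpy scalars are
# used here because save_similarity_output calls .item() on each value.
# Running this requires access to the Hub repo (the "api-key" secret) so the
# result CSVs can be downloaded to /tmp first. temporary=True skips the upload.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    example_results = {
        "similarity": {
            "sparse_MF_correlation": np.float64(0.42),  # hypothetical value
            "sparse_MF_pvalue": np.float64(1.0e-5),     # hypothetical value
        }
    }

    save_results(
        method_name="example_method",    # hypothetical method name
        benchmark_types=["similarity"],
        results=example_results,
        temporary=True,                  # update local CSVs only, no upload
    )

    # Show the leaderboard restricted to the example method and one metric column
    df = get_baseline_df(
        selected_methods=["example_method"],
        selected_metrics=["sim_sparse_MF_correlation"],
    )
    print(df)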