import json
import os
import shutil
import sys
from collections import defaultdict
from statistics import mean

import pandas as pd
import requests

from constants import BASE_WHISPERKIT_BENCHMARK_URL
from text_normalizer import text_normalizer
from utils import compute_average_wer, download_dataset


def fetch_evaluation_data(url, timeout=60):
    """
    Fetches evaluation data from the given URL.
    :param url: The URL to fetch the evaluation data from.
    :param timeout: Seconds to wait for the server before giving up. Without a
        timeout, ``requests`` may block indefinitely on an unresponsive host.
    :returns: The decoded JSON payload (a dictionary for the evaluation files).
    :raises SystemExit: if the request cannot be completed or the response
        status is not 200.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Network-level failures (DNS, connection, timeout) follow the same
        # exit path as HTTP errors instead of surfacing a raw traceback.
        sys.exit(f"Failed to fetch WhisperKit evals: {err}")

    if response.status_code != 200:
        sys.exit(f"Failed to fetch WhisperKit evals: {response.text}")
    return json.loads(response.text)


def process_benchmark_file(file_path, dataset_dfs, device_map, results):
    """
    Processes a single benchmark file and updates the results dictionary.
    :param file_path: Path to the benchmark JSON file. Its parent directory
        name is split on "_" into ``<commit_timestamp>_<commit_hash>``; its file
        name (minus the ``.json`` extension) is split on "_" into
        ``<device>_<company>_<model>_<dataset>_<timestamp>``.
    :param dataset_dfs: Dictionary of DataFrames with reference evaluation
        results, keyed by dataset name; each needs ``file`` and ``wer`` columns.
    :param device_map: Mapping from the device identifier found in the file
        name to the device name used in the results. Files whose device is not
        in this mapping are skipped.
    :param results: Dictionary to store the processed results, keyed by
        ``(model, device, os_info, commit_timestamp)``. Missing keys must be
        created on first access (e.g. a ``defaultdict``).
    This function reads a benchmark JSON file, extracts relevant information,
    and updates the results dictionary with various metrics including WER,
    speed, tokens per second, and quality of inference (QoI).
    """
    with open(file_path, "r") as file:
        test_results = json.load(file)

    if not test_results:
        return

    # The enclosing directory is named "<commit_timestamp>_<commit_hash>".
    commit_dir = os.path.basename(os.path.dirname(file_path))
    commit_timestamp, commit_hash = commit_dir.split("_")

    first_test_result = test_results[0]
    if first_test_result is None:
        return

    # Remove only the ".json" extension. str.strip(".json") would strip any run
    # of the characters ".", "j", "s", "o", "n" from BOTH ends of the name.
    filename = os.path.splitext(os.path.basename(file_path))[0]
    device, company, model, dataset_name, timestamp = filename.split("_")
    model = f"{company}_{model}"

    if device not in device_map:
        return

    device = device_map[device]
    os_info = first_test_result["staticAttributes"]["os"]

    key = (model, device, os_info, commit_timestamp)
    # Both lookups depend only on the file, not on the individual test result.
    dataset_df = dataset_dfs[dataset_name]
    stats = results[key]

    for test_result in test_results:
        if test_result is None:
            continue

        test_info = test_result["testInfo"]
        timings = test_info["timings"]
        audio_file_name = test_info["audioFile"]

        stats["timestamp"] = timestamp
        stats["average_wer"].append(
            {
                "prediction": text_normalizer(test_info["prediction"]),
                "reference": text_normalizer(test_info["reference"]),
            }
        )

        input_audio_seconds = timings["inputAudioSeconds"]
        # fullPipeline is divided by 1000, presumably converting milliseconds to
        # seconds to match inputAudioSeconds (TODO confirm units).
        full_pipeline = timings["fullPipeline"] / 1000
        time_elapsed = test_result["latencyStats"]["measurements"]["timeElapsed"]
        total_decoding_loops = timings["totalDecodingLoops"]

        dataset_speed = stats["dataset_speed"][dataset_name]
        dataset_speed["inputAudioSeconds"] += input_audio_seconds
        dataset_speed["fullPipeline"] += full_pipeline
        stats["speed"]["inputAudioSeconds"] += input_audio_seconds
        stats["speed"]["fullPipeline"] += full_pipeline

        stats["commit_hash"] = commit_hash
        stats["commit_timestamp"] = commit_timestamp

        dataset_tps = stats["dataset_tokens_per_second"][dataset_name]
        dataset_tps["totalDecodingLoops"] += total_decoding_loops
        dataset_tps["timeElapsed"] += time_elapsed
        stats["tokens_per_second"]["totalDecodingLoops"] += total_decoding_loops
        stats["tokens_per_second"]["timeElapsed"] += time_elapsed

        # Reference rows are matched on the audio stem: the file name up to the
        # first ".", then up to the first "-".
        audio_stem = audio_file_name.split(".")[0].split("-")[0]
        # regex=False: the stem is a literal substring, not a regular expression.
        is_match = dataset_df["file"].str.contains(audio_stem, regex=False)
        reference_wer = dataset_df.loc[is_match].iloc[0]["wer"]
        prediction_wer = test_info["wer"]

        # NOTE(review): the threshold scales the reference WER by 110 (not
        # 1.10); confirm the units of both WER values make this intended.
        stats["qoi"].append(1 if prediction_wer <= reference_wer * 110 else 0)


def _rounded_ratio(numerator, denominator):
    """Return ``numerator / denominator`` rounded to two decimal places."""
    return round(numerator / denominator, 2)


def calculate_and_save_performance_results(
    performance_results, performance_output_path
):
    """
    Calculates final performance metrics and saves them to a JSON-lines file.
    :param performance_results: Dictionary containing raw performance data,
        keyed by ``(model, device, os_info, commit_timestamp)`` tuples.
    :param performance_output_path: Path to save the processed performance results.
    :returns: A list of ``(model, device, os_info)`` combinations excluded as
        unsupported. No filtering is currently applied, so the list is always
        empty; it is returned for callers that expect a value.
    This function turns the accumulated totals into ratios (speed, tokens per
    second, per-dataset variants), average WER and mean QoI, and writes one JSON
    object per line, one line per key of ``performance_results``.
    """
    not_supported = []
    with open(performance_output_path, "w") as performance_file:
        for (model, device, os_info, _commit_timestamp), data in performance_results.items():
            speed_totals = data["speed"]
            tps_totals = data["tokens_per_second"]

            performance_entry = {
                "model": model.replace("_", "/"),
                "device": device,
                "os": os_info.replace("_", " "),
                "timestamp": data["timestamp"],
                "speed": _rounded_ratio(
                    speed_totals["inputAudioSeconds"], speed_totals["fullPipeline"]
                ),
                "tokens_per_second": _rounded_ratio(
                    tps_totals["totalDecodingLoops"], tps_totals["timeElapsed"]
                ),
                "dataset_speed": {
                    dataset: _rounded_ratio(
                        speed_info["inputAudioSeconds"], speed_info["fullPipeline"]
                    )
                    for dataset, speed_info in data["dataset_speed"].items()
                },
                "dataset_tokens_per_second": {
                    dataset: _rounded_ratio(
                        tps_info["totalDecodingLoops"], tps_info["timeElapsed"]
                    )
                    for dataset, tps_info in data["dataset_tokens_per_second"].items()
                },
                "average_wer": compute_average_wer(data["average_wer"]),
                "qoi": round(mean(data["qoi"]), 2),
                "commit_hash": data["commit_hash"],
                "commit_timestamp": data["commit_timestamp"],
            }

            json.dump(performance_entry, performance_file)
            performance_file.write("\n")

    return not_supported


def generate_support_matrix(performance_data_path="dashboard_data/performance_data.json", output_file="dashboard_data/support_data.csv"):
    """
    Generate a support matrix CSV showing model compatibility across devices and OS versions.

    Reads ``performance_data_path`` (one JSON object per line, each with
    ``model``, ``device``, ``os`` and optionally ``dataset_speed``) and writes a
    CSV with one row per model and one column per device. Cell values:
    Not Supported: no dataset results recorded for that model/device pair
    ✅ <os versions>: results on two or more distinct datasets
    ⚠️ <os versions>: results on exactly one dataset
    Dataset names are unioned across all entries for a model/device pair, so
    the verdict does not depend on the order of lines in the input file.
    """
    # model -> device -> {"os_versions": set of OS strings, "datasets": set of names}
    support_matrix = defaultdict(lambda: defaultdict(lambda: {
        "os_versions": set(),
        "datasets": set(),
    }))
    models = set()
    devices = set()

    with open(performance_data_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            entry = json.loads(line)
            model = entry["model"]
            device = entry["device"]

            models.add(model)
            devices.add(device)

            cell = support_matrix[model][device]
            cell["os_versions"].add(entry["os"])
            if "dataset_speed" in entry:
                cell["datasets"].update(entry["dataset_speed"])

    sorted_devices = sorted(devices)
    device_columns = [f'"{device}"' for device in sorted_devices]

    # Collect rows first and build the DataFrame once; concatenating inside the
    # loop is quadratic and warns on an empty initial frame in newer pandas.
    rows = []
    for model in sorted(models):
        row_data = {'': model, 'Model': model}
        for device, column in zip(sorted_devices, device_columns):
            cell = support_matrix[model].get(device)
            dataset_count = len(cell["datasets"]) if cell is not None else 0

            if dataset_count == 0:
                row_data[column] = "Not Supported"
                continue

            os_versions = ', '.join(sorted(cell["os_versions"]))
            if dataset_count >= 2:
                row_data[column] = f"✅ {os_versions}"
            else:
                row_data[column] = f"⚠️ {os_versions}"
        rows.append(row_data)

    df = pd.DataFrame(rows, columns=['', 'Model'] + device_columns)
    df.to_csv(output_file, index=False)


def main():
    """
    Main function to orchestrate the performance data generation process.
    This function performs the following steps:
    1. If the first command-line argument is "download", removes any previous
       local copy of the benchmark repo and downloads it again.
    2. Fetches reference evaluation data for each dataset name, requesting
       each distinct URL only once.
    3. Processes every ``.json`` benchmark file under the benchmark directory.
    4. Calculates and saves performance results, then the support matrix.
    """
    source_xcresult_repo = "argmaxinc/whisperkit-evals-dataset"
    source_xcresult_subfolder = "benchmark_data/"
    source_xcresult_directory = f"{source_xcresult_repo}/{source_xcresult_subfolder}"
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree(source_xcresult_repo)
        except FileNotFoundError:
            # Only "no previous download" is expected here; other errors (e.g.
            # permissions) propagate instead of mixing stale and fresh data.
            print("Nothing to remove.")
        download_dataset(
            source_xcresult_repo, source_xcresult_repo, source_xcresult_subfolder
        )

    datasets = {
        "Earnings-22": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "LibriSpeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-12hours": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
    }

    # Several dataset names point at the same URL; download each URL once.
    evals_by_url = {}
    dataset_dfs = {}
    for dataset_name, url in datasets.items():
        if url not in evals_by_url:
            evals_by_url[url] = fetch_evaluation_data(url)
        dataset_dfs[dataset_name] = pd.json_normalize(evals_by_url[url]["results"])

    performance_results = defaultdict(
        lambda: {
            "average_wer": [],
            "qoi": [],
            "speed": {"inputAudioSeconds": 0, "fullPipeline": 0},
            "tokens_per_second": {"totalDecodingLoops": 0, "timeElapsed": 0},
            "dataset_speed": defaultdict(
                lambda: {"inputAudioSeconds": 0, "fullPipeline": 0}
            ),
            "dataset_tokens_per_second": defaultdict(
                lambda: {"totalDecodingLoops": 0, "timeElapsed": 0}
            ),
            "timestamp": None,
            "commit_hash": None,
            "commit_timestamp": None,
            "test_timestamp": None,
        }
    )

    with open("dashboard_data/device_map.json", "r") as f:
        device_map = json.load(f)

    for subdir, dirs, files in os.walk(source_xcresult_directory):
        # os.walk order is filesystem-dependent; sort so that fields overwritten
        # per file (e.g. "timestamp") come out the same on every run.
        dirs.sort()
        for filename in sorted(files):
            if filename.endswith(".json"):
                process_benchmark_file(
                    os.path.join(subdir, filename),
                    dataset_dfs,
                    device_map,
                    performance_results,
                )

    calculate_and_save_performance_results(
        performance_results, "dashboard_data/performance_data.json"
    )

    generate_support_matrix()


if __name__ == "__main__":
    # Run the dashboard data pipeline only when executed as a script, not on import.
    main()