# pylint: disable=no-member
import gradio as gr
import requests
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError
import pandas as pd
import plotly.express as px
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from collections import defaultdict
import numpy as np

HF_API = HfApi()


def apply_power_scaling(sizes: list, exponent: float = 0.2) -> list:
    """Apply power scaling to the sizes, compressing their dynamic range so small files stay visible."""
    return [size**exponent if size is not None else 0 for size in sizes]
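
# Rough intuition for the default exponent (illustrative numbers, not from a
# real repo): 1_000 ** 0.2 ~= 3.98 and 1_000_000_000 ** 0.2 ~= 63.1, so a
# gigabyte-scale file occupies ~63x the treemap area of a 1-byte file rather
# than 10^9x.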


def count_chunks(sizes: list | int) -> list | int:
    """Count the number of 64 KB chunks needed for each size; always round up."""
    if isinstance(sizes, int):
        return int(np.ceil(sizes / 64_000))
    return [int(np.ceil(size / 64_000)) if size is not None else 0 for size in sizes]
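
# For example (hypothetical sizes): count_chunks(64_000) == 1,
# count_chunks(130_000) == 3 (ceil(130_000 / 64_000)), and
# count_chunks([None, 64_001]) == [0, 2]. Note the chunk size here is
# 64_000 bytes, matching the constant above, not 65_536.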


def build_hierarchy(siblings: list) -> dict:
    """Builds a hierarchical structure from the list of RepoSibling objects."""
    hierarchy = defaultdict(dict)

    for sibling in siblings:
        path_parts = sibling.rfilename.split("/")
        size = sibling.lfs.size if sibling.lfs else sibling.size

        current_level = hierarchy
        for part in path_parts[:-1]:
            current_level = current_level.setdefault(part, {})
        current_level[path_parts[-1]] = size

    return hierarchy
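
# Sketch of the resulting shape, assuming two hypothetical siblings with
# rfilenames "config.json" (120 bytes) and "vision/model.safetensors"
# (5_000 bytes):
#   {"config.json": 120, "vision": {"model.safetensors": 5_000}}
# Directory nodes are plain dicts until calculate_directory_sizes below
# annotates them with a "__size__" entry.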


def calculate_directory_sizes(hierarchy):
    """Recursively calculates the size of each directory as the sum of its contents."""
    total_size = 0

    for key, value in hierarchy.items():
        if isinstance(value, dict):
            dir_size = calculate_directory_sizes(value)
            hierarchy[key] = {
                "__size__": dir_size,
                **value,
            }
            total_size += dir_size
        else:
            # Files with missing size metadata are stored as None; count them as 0.
            total_size += value or 0

    return total_size
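
# Given the hypothetical hierarchy sketched above, this mutates it in place to
#   {"config.json": 120, "vision": {"__size__": 5_000, "model.safetensors": 5_000}}
# and returns 5_120, the total for the whole tree.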


def build_full_path(current_parent, key):
    """Join a parent path and key into a unique node ID for the treemap."""
    return f"{current_parent}/{key}" if current_parent else key


def flatten_hierarchy(hierarchy, root_name="Repository"):
    """Flatten a nested dictionary into Plotly-compatible treemap data with a defined root node."""
    labels = []
    parents = []
    sizes = []
    ids = []

    # Recursively process the hierarchy
    def process_level(current_hierarchy, current_parent):
        for key, value in current_hierarchy.items():
            full_path = build_full_path(current_parent, key)
            if isinstance(value, dict) and "__size__" in value:
                # Handle directories
                dir_size = value.pop("__size__")
                labels.append(key)
                parents.append(current_parent)
                sizes.append(dir_size)
                ids.append(full_path)
                process_level(value, full_path)
            else:
                # Handle files
                labels.append(key)
                parents.append(current_parent)
                sizes.append(value)
                ids.append(full_path)

    # Add the root node
    total_size = calculate_directory_sizes(hierarchy)
    labels.append(root_name)
    parents.append("")
    sizes.append(total_size)
    ids.append(root_name)

    # Process the hierarchy
    process_level(hierarchy, root_name)

    return labels, parents, sizes, ids
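
# Continuing the hypothetical example, flatten_hierarchy(hierarchy, "user/repo")
# would yield parallel lists roughly like:
#   labels  = ["user/repo", "config.json", "vision", "model.safetensors"]
#   parents = ["", "user/repo", "user/repo", "vision"]
#   ids     = ["user/repo", "config.json", "vision", "vision/model.safetensors"]
# The ids are full paths so that files with the same name in different
# directories remain distinct treemap nodes.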


def visualize_repo_treemap(r_info: dict, r_id: str) -> px.treemap:
    """Visualizes the repository as a treemap with directory sizes and human-readable tooltips."""
    siblings = r_info.siblings
    hierarchy = build_hierarchy(siblings)

    # Flatten the hierarchy for Plotly. flatten_hierarchy computes directory
    # sizes internally, so calling calculate_directory_sizes here as well
    # would double-count every directory in the root total.
    labels, parents, sizes, ids = flatten_hierarchy(hierarchy, r_id)

    # Scale sizes for visualization so huge files don't dwarf everything else
    scaled_sizes = apply_power_scaling(sizes)

    # Format the original sizes using the helper function
    formatted_sizes = [
        (format_repo_size(size) if size is not None else None) for size in sizes
    ]

    chunks = count_chunks(sizes)
    colors = scaled_sizes[:]
    colors[0] = -1
    max_value = max(scaled_sizes)
    normalized_colors = [value / max_value if value > 0 else 0 for value in colors]

    # Define the colorscale; mimics the plasma scale
    colorscale = [
        [0.0, "#0d0887"],
        [0.5, "#bd3786"],
        [1.0, "#f0f921"],
    ]

    # Create the treemap
    fig = px.treemap(
        names=labels,
        parents=parents,
        values=scaled_sizes,
        color=normalized_colors,
        color_continuous_scale=colorscale,
        title=f"{r_id} by Chunks",
        custom_data=[formatted_sizes, chunks],
        height=1000,
        ids=ids,
    )

    fig.update_traces(marker={"colors": ["lightgrey"] + normalized_colors[1:]})

    # Add subtitle by updating the layout
    fig.update_layout(
        title={
            "text": f"{r_id} file and chunk treemap<br><span style='font-size:14px;'>Color represents size in bytes/chunks.</span>",
            "x": 0.5,
            "xanchor": "center",
        },
        coloraxis_showscale=False,
    )

    # Customize the hover template and grey out the root node
    fig.update_traces(
        hovertemplate=(
            "<b>%{label}</b><br>"
            "Size: %{customdata[0]}<br>"
            "# of Chunks: %{customdata[1]}"
        ),
        root_color="lightgrey",
    )

    return fig


def format_repo_size(r_size: int) -> str:
    """
    Convert a repository size in bytes to a human-readable string with appropriate units.

    Args:
        r_size (int): The size of the repository in bytes.

    Returns:
        str: The formatted size string with appropriate units (B, KB, MB, GB, TB, PB).
    """
    units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB", 5: "PB"}
    order = 0
    while r_size >= 1024 and order < len(units) - 1:
        r_size /= 1024
        order += 1
    return f"{r_size:.2f} {units[order]}"


def repo_files(r_type: str, r_id: str) -> tuple[dict, px.treemap]:
    """Aggregate per-extension sizes, chunk counts, and file counts, plus the treemap figure."""
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    fig = visualize_repo_treemap(r_info, r_id)
    files = {}
    for sibling in r_info.siblings:
        ext = sibling.rfilename.split(".")[-1]
        size = sibling.size or 0  # size metadata can be missing
        if ext in files:
            files[ext]["size"] += size
            files[ext]["chunks"] += count_chunks(size)
            files[ext]["count"] += 1
        else:
            files[ext] = {
                "size": size,
                "chunks": count_chunks(size),
                "count": 1,
            }
    return files, fig
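
# The returned mapping is keyed by file extension, e.g. (hypothetical values):
#   {"safetensors": {"size": 5_000_000_000, "chunks": 78_125, "count": 2}, ...}
# get_repo_info below turns this into the "File Sizes" dataframe.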


def repo_size(r_type, r_id):
    try:
        r_refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    except RepositoryNotFoundError:
        gr.Warning(f"Repository is gated, branch information for {r_id} not available.")
        return {}
    repo_sizes = {}
    for branch in r_refs.branches:
        try:
            response = requests.get(
                f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{branch.name}",
                timeout=1000,
            )
            data = response.json()
        except Exception:
            data = {}
        error = data.get("error")
        if error and ("restricted" in error or "gated" in error):
            gr.Warning(f"Branch information for {r_id} not available.")
            return {}
        size = data.get("size")
        if size is not None:
            repo_sizes[branch.name] = {
                "size_in_bytes": size,
                "size_in_chunks": count_chunks(size),
            }

    return repo_sizes
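
# The treesize endpoint is expected to return JSON like {"size": <bytes>} on
# success or {"error": "..."} for restricted/gated repos, so a successful
# result here looks like (hypothetical values):
#   {"main": {"size_in_bytes": 1_280_000, "size_in_chunks": 20}}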


def get_repo_info(r_type, r_id):
    try:
        repo_sizes = repo_size(r_type, r_id)
        repo_files_info, treemap_fig = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and type that corresponds to the repository."
        )
        return (
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
            gr.Plot(visible=False),
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
        )

    # Hide the branch-size block when no branch information was retrieved
    if not repo_sizes:
        r_sizes_component = gr.Dataframe(visible=False)
        b_block = gr.Row(visible=False)
    else:
        r_sizes_df = pd.DataFrame(repo_sizes).T.reset_index(names="branch")
        r_sizes_df["formatted_size"] = r_sizes_df["size_in_bytes"].apply(
            format_repo_size
        )
        r_sizes_df.columns = ["Branch", "size_in_bytes", "Chunks", "Size"]
        r_sizes_component = gr.Dataframe(
            value=r_sizes_df[["Branch", "Size", "Chunks"]], visible=True
        )
        b_block = gr.Row(visible=True)

    rf_sizes_df = (
        pd.DataFrame(repo_files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )
    rf_sizes_df["formatted_size"] = rf_sizes_df["size"].apply(format_repo_size)
    rf_sizes_df.columns = ["Extension", "bytes", "Chunks", "Count", "Size"]
    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=rf_sizes_df[["Extension", "Count", "Size", "Chunks"]],
            visible=True,
        ),
        gr.Plot(treemap_fig, visible=True),
        b_block,
        r_sizes_component,
    )


with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Chunking Repos")
    gr.Markdown(
        "Search for a model or dataset repository using the autocomplete below, select the repository type, and get back information about the repository's contents including the [number of chunks each file might be split into with Xet backed storage](https://huggingface.co/blog/from-files-to-chunks)."
    )
    with gr.Blocks():
        repo_id = HuggingfaceHubSearch(
            label="Hub Repository Search (enter user, organization, or repository name to start searching)",
            placeholder="Search for model or dataset repositories on Huggingface",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    with gr.Blocks():
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## Repo Info")
                gr.Markdown(
                    "Hover over any file or directory to see it's size in bytes and total number of chunks required to store it in Xet storage."
                )
                file_info_plot = gr.Plot(visible=False)
                with gr.Row(visible=False) as branch_block:
                    with gr.Column():
                        gr.Markdown("### Branch Sizes")
                        gr.Markdown(
                            "The size of each branch in the repository and how many chunks it might need (assuming no dedupe)."
                        )
                        branch_sizes = gr.Dataframe(visible=False)
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### File Sizes")
                        gr.Markdown(
                            "The cumulative size of each filetype in the repository (in the `main` branch) and how many chunks they might need (assuming no dedupe)."
                        )
                        file_info = gr.Dataframe(visible=False)

    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results_block, file_info, file_info_plot, branch_block, branch_sizes],
    )

demo.launch()