import json
import random

import httpx
import polars as pl
from huggingface_hub import list_datasets
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio


# Shared async HTTP client for all dataset-viewer API requests
# (http2=True requires the h2 package, e.g. `pip install "httpx[http2]"`)
client = httpx.AsyncClient(timeout=60, http2=True)


async def generate_dataset_prompt(dataset_name, num_rows=2):
    """Build a formatted prompt describing a dataset via the dataset-viewer API."""
    try:
        base_url = "https://datasets-server.huggingface.co"

        # Get splits and configs
        splits_url = f"{base_url}/splits?dataset={dataset_name}"
        splits_response = await client.get(splits_url)
        splits_data = splits_response.json()

        if not splits_data.get("splits"):
            return None

        # Get the first config and split
        first_split = splits_data["splits"][0]
        config_name = first_split["config"]
        split_name = first_split["split"]

        # Get dataset info for the specific config
        info_url = f"{base_url}/info?dataset={dataset_name}&config={config_name}"
        info_response = await client.get(info_url)
        info_data = info_response.json()

        # Get first rows for the specific config and split
        first_rows_url = f"{base_url}/first-rows?dataset={dataset_name}&config={config_name}&split={split_name}"
        first_rows_response = await client.get(first_rows_url)
        first_rows_data = first_rows_response.json()

        # Get size information
        size_url = f"{base_url}/size?dataset={dataset_name}"
        size_response = await client.get(size_url)
        size_data = size_response.json()

        # Extract relevant information
        dataset_info = info_data.get("dataset_info", {})
        features = dataset_info.get("features", {})
        splits = dataset_info.get("splits", {})

        # Calculate total examples and size
        total_examples = sum(split.get("num_examples", 0) for split in splits.values())
        total_size = (
            size_data.get("size", {})
            .get("dataset", {})
            .get("num_bytes_original_files", 0)
        )

        # Format features
        def format_feature(name, details):
            if isinstance(details, dict):
                feature_type = details.get(
                    "dtype", details.get("_type", "unknown type")
                )
            elif isinstance(details, list):
                feature_type = "list"
            else:
                feature_type = str(type(details).__name__)
            return f"- {name} ({feature_type})"

        formatted_features = "\n".join(
            format_feature(name, details) for name, details in features.items()
        )

        # Format a sample of the first rows (num_rows of them) as pretty-printed JSON
        sample_data = json.dumps(first_rows_data.get("rows", [])[:num_rows], indent=2)

        # Create the formatted prompt
        prompt = f"""
Dataset: "{dataset_name}"

Features:
{formatted_features}

Splits and Configs:
{', '.join(f"{split['config']}/{split['split']}" for split in splits_data['splits'])}

Size Statistics:
Total Examples: {total_examples}
Split Sizes: {', '.join(f"{split}: {info['num_examples']}" for split, info in splits.items())}

Data Sample ({num_rows} rows out of {total_examples} total):
{sample_data}
"""

        return prompt.strip()
    except Exception as e:
        print(f"Error for {dataset_name}: {e}")
        return None
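
# A minimal sketch of calling generate_dataset_prompt on its own, reusing the
# module-level client above; the dataset name "stanfordnlp/imdb" is only an
# illustrative example, not one used elsewhere in this script:
#
#   import asyncio
#   prompt = asyncio.run(generate_dataset_prompt("stanfordnlp/imdb", num_rows=2))
#   print(prompt)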


async def process_batch(batch):
    """Generate prompts for a batch of dataset IDs concurrently, dropping failures."""
    results = await tqdm_asyncio.gather(
        *[generate_dataset_prompt(dataset) for dataset in batch], leave=False
    )
    return [
        (dataset_id, prompt)
        for dataset_id, prompt in zip(batch, results)
        if prompt is not None
    ]


async def prep_data(sample_size=200_000, min_likes=1):
    """Sample Hub datasets not already in train/test and build prompts for them."""
    # Load the dataset containing the dataset IDs already used for train/test
    df = pl.read_parquet(
        "hf://datasets/davanstrien/dataset-viewer-descriptions-processed/data/train-00000-of-00001.parquet"
    )
    # Exclude datasets that are already in the train or test set (this filter
    # can be dropped later once the model works well enough)
    in_train_or_test = set(df["dataset_id"].unique().to_list())

    # Get all datasets
    datasets = [
        dataset for dataset in list_datasets() if dataset.id not in in_train_or_test
    ]
    # Keep only datasets with at least `min_likes` likes (skipped if min_likes is falsy)
    if min_likes:
        datasets = [
            dataset for dataset in datasets if (dataset.likes or 0) >= min_likes
        ]
    datasets = [dataset.id for dataset in datasets]
    # Randomly sample up to sample_size dataset IDs
    datasets = random.sample(datasets, min(sample_size, len(datasets)))

    # Process datasets in batches of `batch_size`
    batch_size = 500
    all_results = []

    for i in tqdm(range(0, len(datasets), batch_size), desc="Processing batches"):
        batch = datasets[i : i + batch_size]
        batch_results = await process_batch(batch)
        all_results.extend(batch_results)

        # Optional: save intermediate results whenever the running total crosses
        # a multiple of 1,000 prompts, so progress survives interruptions
        if len(all_results) // 1000 > (len(all_results) - len(batch_results)) // 1000:
            intermediate_df = pl.DataFrame(
                {
                    "dataset_id": [row[0] for row in all_results],
                    "formatted_prompt": [row[1] for row in all_results],
                }
            )
            intermediate_df.write_parquet(
                f"dataset_prompts_intermediate_{len(all_results)}.parquet"
            )

    return pl.DataFrame(
        {
            "dataset_id": [row[0] for row in all_results],
            "formatted_prompt": [row[1] for row in all_results],
        }
    )
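

# A minimal entry-point sketch, assuming this file is run directly as a script;
# the sample_size value and the output filename below are illustrative choices,
# not taken from the original pipeline.
if __name__ == "__main__":
    import asyncio

    async def main():
        results_df = await prep_data(sample_size=10_000, min_likes=1)
        results_df.write_parquet("dataset_prompts.parquet")
        await client.aclose()  # close the shared HTTP client cleanly

    asyncio.run(main())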