import asyncio
import logging

import chromadb
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

from prep_viewer_data import prep_data
from utils import get_chroma_client

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

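# Fine-tuned embedding model, pinned to a fixed revision for reproducibility.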
EMBEDDING_MODEL_NAME = "davanstrien/query-to-dataset-viewer-descriptions"
EMBEDDING_MODEL_REVISION = "07c71d97861a73695f0c53cd6b4b32980007d908"
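# Dedicated Hugging Face Inference Endpoint used for remote embedding requests.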
INFERENCE_MODEL_URL = (
    "https://ecg0by60w2vo9j8h.us-east-1.aws.endpoints.huggingface.cloud"
)


def initialize_clients():
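    """Create and return the ChromaDB client and the remote InferenceClient."""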
    logger.info("Initializing clients")
    chroma_client = get_chroma_client()
    inference_client = InferenceClient(
        INFERENCE_MODEL_URL,
    )
    return chroma_client, inference_client


def create_collection(chroma_client):
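    """Get or create the Chroma collection used for dataset-viewer search.

    The HNSW index uses cosine distance. Attaching the SentenceTransformer
    embedding function lets Chroma embed raw-text queries locally; the remote
    inference endpoint is assumed to serve the same pinned model, so stored
    and query vectors stay comparable.
    """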
    logger.info("Creating or getting collection")
    embedding_function = SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
        trust_remote_code=True,
        revision=EMBEDDING_MODEL_REVISION,
    )
    logger.info(f"Embedding function: {embedding_function}")
    logger.info(f"Embedding model name: {EMBEDDING_MODEL_NAME}")
    logger.info(f"Embedding model revision: {EMBEDDING_MODEL_REVISION}")
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedding_function,
        metadata={"hnsw:space": "cosine"},
    )


@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
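    """Embed a single document via the remote inference endpoint.

    The stamina decorator retries failed HTTP calls up to 3 times with
    backoff (10s initial wait). Input is truncated to the first 8192
    characters as a rough guard against the model's context limit; note
    this is a character count, not a token count.
    """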
    text = text[:8192]
    return client.feature_extraction(text)


def embed_and_upsert_datasets(
    dataset_rows_and_ids: list[dict[str, str]],
    collection: chromadb.Collection,
    inference_client: InferenceClient,
    batch_size: int = 100,
):
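    """Embed dataset prompts in batches and upsert the vectors into Chroma.

    Each batch of `batch_size` documents is embedded concurrently via
    thread_map, then upserted into the collection keyed by dataset id.
    """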
    logger.info(
        f"Embedding and upserting {len(dataset_rows_and_ids)} datasets for viewer data"
    )
    for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
        batch = dataset_rows_and_ids[i : i + batch_size]
        ids = []
        documents = []
        for item in batch:
            ids.append(item["dataset_id"])
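            # Prefix marks the text as a Hub dataset preview; assumed to match
            # the prompt format the fine-tuned embedding model expects.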
            documents.append(f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}")
        results = thread_map(
            lambda doc: embed_card(doc, inference_client), documents, leave=False
        )
        logger.info(f"Results: {len(results)}")
        collection.upsert(
            ids=ids,
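            # feature_extraction appears to return a (1, dim) array per
            # document, so take the single row as the embedding vector.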
            embeddings=[embedding.tolist()[0] for embedding in results],
        )
        logger.debug(f"Processed batch {i // batch_size + 1}")


async def refresh_viewer_data(sample_size=200_000, min_likes=2):
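    """Rebuild the dataset-viewer embedding collection end to end:
    prepare the data, snapshot it to parquet, then embed and upsert.
    """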
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Collection created successfully")
    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    if df is None:
        logger.error("Data preparation returned no data; aborting refresh")
        return
    logger.info("Data prepared successfully")
    logger.info(f"Data: {df}")
    df.write_parquet("viewer_data.parquet")

    dataset_rows_and_ids = df.to_dicts()

    logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
    embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
    logger.info("Refresh completed successfully")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(refresh_viewer_data())