import asyncio
import hmac
import os
import threading

from fastapi import FastAPI, Request, Response
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

from datasets import load_dataset
from huggingface_hub import push_to_hub_keras, from_pretrained_keras

# Shared secret for the webhook endpoint, read from the environment;
# ``None`` when WEBHOOK_SECRET is not set.
KEY = os.getenv("WEBHOOK_SECRET")

app = FastAPI()

def to_numpy(examples):
    """Add a ``"pixel_values"`` column of NumPy arrays to a batch of examples.

    Used with ``Dataset.map(..., batched=True)``. Each entry of
    ``examples["image"]`` is converted with ``.convert('1')`` and wrapped in
    ``np.array``. The resulting arrays are stored, in the same order, as a
    list under ``"pixel_values"``. The batch dict is updated in place and
    also returned.
    """
    pixel_values = []
    for img in examples["image"]:
        binarized = img.convert('1')
        pixel_values.append(np.array(binarized))
    examples["pixel_values"] = pixel_values
    return examples

def preprocess():
    """Load the labeled training split and the test split as model-ready arrays.

    Both splits are passed through ``to_numpy``. A trailing channel axis is
    added to the pixel arrays, and labels are one-hot encoded over 10 classes
    with ``keras.utils.to_categorical``.

    Returns:
        ``(x_train, y_train, x_test, y_test)``.
    """
    num_classes = 10

    def load_split(repo_id, split):
        # Fetch one split, attach pixel arrays, and turn it into (inputs, one-hot labels).
        ds = load_dataset(repo_id)[split].map(to_numpy, batched=True)
        images = np.expand_dims(ds["pixel_values"], -1)
        labels = keras.utils.to_categorical(ds["label"], num_classes)
        return images, labels

    x_train, y_train = load_split("active-learning/labeled_samples", "train")
    x_test, y_test = load_split("active-learning/test_mnist", "test")
    return x_train, y_train, x_test, y_test

def train():
    """Fit a small CNN on the labeled samples and publish it to the Hub.

    Prints the test-set loss and accuracy. It then uploads the fitted model
    to ``active-learning/mnist_classifier`` with ``push_to_hub_keras``.
    """
    x_train, y_train, x_test, y_test = preprocess()
    num_classes = 10
    image_shape = (28, 28, 1)

    model = keras.Sequential(
        [
            keras.Input(shape=image_shape),
            # Convolutional feature extractor.
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            # Classification head.
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )

    model.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    model.fit(
        x_train,
        y_train,
        batch_size=128,
        epochs=15,
        validation_split=0.1,
    )

    evaluation = model.evaluate(x_test, y_test, verbose=0)
    test_loss, test_accuracy = evaluation[0], evaluation[1]
    print("Test loss:", test_loss)
    print("Test accuracy:", test_accuracy)

    push_to_hub_keras(model, "active-learning/mnist_classifier")

def find_samples_to_label():
    loaded_model = from_pretrained_keras("active-learning/mnist_classifier")
    loaded_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    
    unlabeled_data = load_dataset("active-learning/unlabeled_samples")["train"]
    processed_data = unlabeled_data.map(to_numpy, batched=True)
    processed_data = processed_data["pixel_values"]
    processed_data = tf.expand_dims(processed_data, -1)

    # Get all predictions
    # And then get the 5 samples with the lowest prediction score
    preds = loaded_model.predict(processed_data)
    top_pred_confs = 1 - np.max(preds, axis=1)
    idx_to_label = np.argpartition(top_pred_confs, -5)[-5:]

    # Upload samples to the dataset to label
    to_label_data = unlabeled_data.select(idx_to_label)
    to_label_data.push_to_hub("active-learning/to_label_samples")

    # Remove from the pool of samples
    unlabeled_data = unlabeled_data.select(
        (
            i for i in range(len(unlabeled_data)) 
            if i not in set(idx_to_label)
        )
    )
    unlabeled_data.push_to_hub("active-learning/unlabeled_samples")

@app.get("/")
def read_root():
    """Serve a minimal static HTML landing page describing the app."""
    data = """
    <h2 style="text-align:center">Active Learning Trainer</h2>
    <p style="text-align:center">This is a demo app showing how to use webhooks to do Active Learning.</p>
    """
    return Response(content=data, media_type="text/html")

# Serializes retraining rounds. Each round reads and rewrites the same Hub
# repos, so overlapping webhook deliveries must not interleave.
_pipeline_lock = threading.Lock()


def _run_active_learning_round():
    """Retrain the classifier, then select the next samples to label (blocking)."""
    with _pipeline_lock:
        train()
        find_samples_to_label()


@app.post("/webhook")
async def webhook(request: Request):
    """Authenticate a webhook call and run one active-learning round.

    The request must carry an ``X-Webhook-Secret`` header equal to ``KEY``;
    otherwise a 401 response is returned. When ``KEY`` is ``None`` (secret
    not configured), every request is rejected. The round runs in an executor
    thread, so the blocking training and Hub uploads don't stall the event
    loop. The response is sent only after the round finishes.
    """
    print("Received request")
    secret = request.headers.get("X-Webhook-Secret")
    if secret is None:
        return Response("No secret", status_code=401)
    # Constant-time comparison, so response timing doesn't leak the secret.
    if KEY is None or not hmac.compare_digest(secret.encode(), KEY.encode()):
        return Response("Invalid secret", status_code=401)
    # The JSON body is parsed but its contents are not inspected.
    await request.json()
    print("Webhook received!")
    # NOTE(review): the round pushes to Hub repos; if this webhook watches
    # those repos, each round may re-trigger it — confirm the subscription.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _run_active_learning_round)
    return "Model trained!"