import os

import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms

from templates import openai_imagenet_template

# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")
# Gradio flagging callback targeting the "bioclip-demo" dataset; it is passed as
# flagging_callback to the gr.Interface built under __main__.
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

# open_clip model identifier (loaded with create_model under __main__) and the
# tokenizer config name (passed to get_tokenizer under __main__).
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Image preprocessing: convert to a float tensor, then normalize each channel.
# NOTE(review): there is no resize/crop step; the input size presumably comes
# from gr.Image(shape=(224, 224)) in the UI. Confirm before reusing this elsewhere.
# NOTE(review): mean/std appear to be the OpenAI CLIP normalization statistics.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


@torch.no_grad()
def get_txt_features(classnames, templates):
    """
    Build one text embedding per class by averaging over prompt templates.

    Each class name is expanded with every template. The prompts are tokenized
    and encoded, each per-prompt embedding is L2-normalized, and the embeddings
    are averaged. The average is then rescaled to unit length.

    Args:
        classnames: Iterable of class-name strings.
        templates: Callables mapping a class name to a prompt string.

    Returns:
        Tensor with the per-class embeddings stacked along dim 1, so column ``i``
        belongs to ``classnames[i]``. This assumes ``model.encode_text`` returns
        one row per prompt.

    Raises:
        RuntimeError: From ``torch.stack`` when ``classnames`` is empty.
    """

    def embed_one(name):
        prompts = [fill(name) for fill in templates]
        tokens = tokenizer(prompts).to(device)
        per_prompt = F.normalize(model.encode_text(tokens), dim=-1)
        centroid = per_prompt.mean(dim=0)
        # Averaging unit vectors shrinks the norm, so rescale back to unit length.
        return centroid / centroid.norm()

    return torch.stack([embed_one(name) for name in classnames], dim=1)


@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    """
    Zero-shot classify ``img`` against a user-supplied list of class names.

    Args:
        img: Image accepted by ``preprocess_img`` (``transforms.ToTensor``),
            e.g. the value produced by the ``gr.Image`` input.
        classes: Candidate class names. Surrounding whitespace is stripped and
            blank entries are dropped.

    Returns:
        Mapping from each stripped class name to its softmax probability.
        Empty when no non-blank class names were given.
    """
    classes = [cls.strip() for cls in classes if cls.strip()]
    if not classes:
        # Nothing to score against; get_txt_features would fail on torch.stack([]).
        return {}

    # One embedding per class, stacked along dim 1.
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))  # batch of one image
    img_features = F.normalize(img_features, dim=-1)

    # The result has a leading batch axis of size 1. Squeeze only that axis:
    # a bare .squeeze() turns a single-class result into a 0-d tensor, whose
    # .tolist() is a float and breaks the zip below.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))


@torch.no_grad()
def hierarchical_predict(img) -> dict[str, float]:
    """
    Predicts from the top of the tree of life down to the species.

    Not implemented yet. The intended result is a ``{label: probability}``
    mapping, the same shape that ``predict`` returns and that ``run`` passes to
    the ``gr.Label`` output.

    Args:
        img: Image accepted by ``preprocess_img``.

    Raises:
        NotImplementedError: Always, until the hierarchical search is written.
    """
    # TODO: encode the image as predict() does (preprocess_img ->
    # model.encode_image -> F.normalize), then score it rank by rank from the
    # top of the taxonomy down to the species.
    raise NotImplementedError(
        "Hierarchical (tree-of-life) prediction is not implemented yet; "
        "enter one or more class names instead."
    )


def run(img, cls_str: str) -> dict[str, float]:
    """
    Gradio entry point: dispatch to zero-shot or tree-of-life prediction.

    Args:
        img: Image from the ``gr.Image`` input; ``None`` when nothing was uploaded.
        cls_str: Newline-separated class names from the "Classes" textbox.
            If it contains no non-blank lines, the whole tree of life is used.

    Returns:
        Mapping of label to probability for the ``gr.Label`` output.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Fail with a user-visible message instead of an opaque error from
        # preprocess_img deep inside the model code.
        raise gr.Error("Please upload an image.")

    # Parse before dispatching so whitespace-only input counts as "no classes"
    # rather than reaching predict() with an empty list.
    classes = [cls.strip() for cls in (cls_str or "").split("\n") if cls.strip()]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)


if __name__ == "__main__":
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    # open_clip creates models in training mode; switch to inference mode so
    # that any train-only layer behavior is disabled while serving.
    model = model.to(device).eval()
    print("Created model.")

    # NOTE(review): torch.compile compiles forward(), but this app only calls
    # model.encode_image / model.encode_text. The compiled wrapper delegates
    # those attributes to the original module, so this likely yields no
    # speedup. Confirm, or compile those methods explicitly.
    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()