import json
import gradio as gr
import torch
import PIL
import os
from models import ViTClassifier
from datasets import load_dataset
from transformers import TrainingArguments, ViTConfig, ViTForImageClassification
from torchvision import transforms
import pandas as pd

def load_config(config_path):
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Config Loaded:", config)  # Debugging
    return config

def load_model(config, device='cuda'):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    ckpt = torch.load(config['checkpoint_path'], map_location=device)
    print("Checkpoint Loaded:", ckpt.keys())  # Debugging
    model = ViTClassifier(config, device=device, dtype=torch.float32)
    print("Model Loaded:", model)  # Debugging
    model.load_state_dict(ckpt['model'])
    return model.to(device).eval()

def prepare_model_for_push(model, config):
    # Create a VisionTransformerConfig
    vit_config = ViTConfig(
        image_size=config['model']['input_size'],
        patch_size=config['model']['patch_size'],
        hidden_size=config['model']['hidden_size'],
        num_heads=config['model']['num_attention_heads'],
        num_layers=config['model']['num_hidden_layers'],
        mlp_ratio=4,  # Common default for ViT
        hidden_dropout_prob=config['model']['hidden_dropout_prob'],
        attention_probs_dropout_prob=config['model']['attention_probs_dropout_prob'],
        layer_norm_eps=config['model']['layer_norm_eps'],
        num_classes=config['model']['num_classes']
    )
    # Create a VisionTransformer model
    vit_model = ViTForImageClassification(vit_config)
    # Copy the weights from your custom model to the VisionTransformer model
    state_dict = vit_model.state_dict()
    for key in state_dict.keys():
        if key in model.state_dict():
            state_dict[key] = model.state_dict()[key]
    vit_model.load_state_dict(state_dict)
    return vit_model, vit_config

def run_inference(input_image, model):
    print("Input Image Type:", type(input_image))  # Debugging
    # Directly use the PIL Image object
    fake_prob = model.forward(input_image).item()
    result_description = get_result_description(fake_prob)
    return {
        "Fake Probability": fake_prob,
        "Result Description": result_description
    }

def get_result_description(fake_prob):
    if fake_prob > 0.5:
        return "The image is likely a fake."
    else:
        return "The image is likely real."

def run_evaluation(dataset_name, model, config, device):
    dataset = load_dataset(dataset_name)
    eval_df, accuracy = evaluate_model(model, dataset, config, device)
    return accuracy, eval_df.to_csv(index=False)

def evaluate_model(model, dataset, config, device):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()
    norm_mean = config['preprocessing']['norm_mean']
    norm_std = config['preprocessing']['norm_std']
    resize_size = config['preprocessing']['resize_size']
    crop_size = config['preprocessing']['crop_size']
    augment_list = [
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
        transforms.ConvertImageDtype(torch.float32),
    ]
    preprocess = transforms.Compose(augment_list)
    true_labels = []
    predicted_probs = []
    predicted_labels = []
    with torch.no_grad():
        for sample in dataset:
            image = sample['image']
            label = sample['label']
            image = preprocess(image).unsqueeze(0).to(device)
            output = model.forward(image)
            prob = output.item()
            true_labels.append(label)
            predicted_probs.append(prob)
            predicted_labels.append(1 if prob > 0.5 else 0)
    eval_df = pd.DataFrame({
        'True Label': true_labels,
        'Predicted Probability': predicted_probs,
        'Predicted Label': predicted_labels
    })
    accuracy = (eval_df['True Label'] == eval_df['Predicted Label']).mean()
    return eval_df, accuracy

def main():
    # Load configuration
    config_path = "config.json"
    config = load_config(config_path)
    # Load model
    device = config['device']
    model = load_model(config, device=device)
    # Define Gradio interface for inference
    def gradio_interface(input_image):
        return run_inference(input_image, model)
    # Create Gradio Tabs
    with gr.Blocks() as demo:
        gr.Markdown("# Community Forensics Dataset Trained Detection Model")
        with gr.Tab("Image Inference"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Column():
                    output = gr.JSON(label="Classification Result")
            input_image.change(fn=gradio_interface, inputs=input_image, outputs=output)
    # Launch the Gradio app
    demo.launch()

if __name__ == "__main__":
    main()