import torch
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from sentence_transformers import SentenceTransformer
import random
import urllib.parse
import torch.nn as nn
from sklearn.metrics import classification_report
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gradio as gr
from io import BytesIO

# Device setup: use the MPS backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Data transformation
# Every input image is resized to 224x224, converted to a tensor, and then
# normalised per channel with the mean/std triples below.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load datasets for enriched prompts
# Both files are ';'-separated. Column headers are lower-cased right after
# loading; the lookups in enrich_prompt use the lower-case names.
dataset_desc = pd.read_csv(
    "dataset_desc.csv",
    sep=';',
    usecols=['Artists', 'Style', 'Description'],
).rename(columns=str.lower)
style_desc = pd.read_csv("style_desc.csv", sep=';').rename(columns=str.lower)

# Function to enrich prompts with custom data
def _first_description(table, key_column, key):
    """Return the first usable ``description`` for ``key`` in ``table``, or ``None``.

    Rows are matched by exact equality on ``key_column``. Missing (NaN/None)
    and blank descriptions are skipped so they never leak into a prompt as
    the literal text "nan" or as an empty clause.
    """
    matches = table.loc[table[key_column] == key, 'description'].dropna()
    for raw_text in matches:
        text = str(raw_text).strip()
        if text:
            return text
    return None


def enrich_prompt(artist, style):
    """Build background text about a predicted artist and style.

    Looks up ``artist`` in the module-level ``dataset_desc`` table (column
    ``artists``) and ``style`` in ``style_desc`` (column ``style``), taking
    the first non-missing, non-blank ``description`` for each.

    Args:
        artist: Artist name, matched exactly against ``dataset_desc['artists']``.
        style: Style name, matched exactly against ``style_desc['style']``.

    Returns:
        A string of the form
        ``"<artist text> This work exemplifies <style text>."`` when a style
        description is found. When it is not, the artist text is followed by
        the stand-alone sentence "Details about the style are not available."
        When no artist description is found, the artist text is
        "Details about the artist are not available.".
    """
    artist_details = _first_description(dataset_desc, 'artists', artist)
    if artist_details is None:
        artist_details = "Details about the artist are not available."

    style_details = _first_description(style_desc, 'style', style)
    if style_details is None:
        # The fallback is already a complete sentence; embedding it after
        # "This work exemplifies" would read as nonsense and end in "..".
        return f"{artist_details} Details about the style are not available."

    # The template supplies the final period, so drop any the description
    # already carries to avoid a doubled "..".
    style_details = style_details.rstrip(". ")
    return f"{artist_details} This work exemplifies {style_details}."

# Custom dataset for ResNet18
class ArtDataset:
    """Annotation table for the artwork classifier plus label-index lookups.

    The CSV passed to the constructor must provide ``subset``, ``genre`` and
    ``artist`` columns. Rows whose ``subset`` is ``'train'`` or ``'test'`` are
    exposed as separate frames, and every distinct ``genre`` / ``artist``
    value is assigned an integer index.
    """

    def __init__(self, csv_file):
        table = pd.read_csv(csv_file)
        self.annotations = table

        # Partition rows by the value of the 'subset' column.
        subset = table['subset']
        self.train_data = table[subset == 'train']
        self.test_data = table[subset == 'test']

        # Label indices are computed over the whole table (train and test
        # rows alike), in the order pandas' unique() yields the values.
        self.label_map_style = self._index_labels(table['genre'])
        self.label_map_artist = self._index_labels(table['artist'])

    @staticmethod
    def _index_labels(column):
        """Map each distinct value of ``column`` to a consecutive integer."""
        return {label: position for position, label in enumerate(column.unique())}

    def get_style_and_artist_mappings(self):
        """Return ``(label_map_style, label_map_artist)``."""
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        """Return ``(train_data, test_data)``."""
        return self.train_data, self.test_data

# DualOutputResNet model with Dropout
class DualOutputResNet(nn.Module):
    """ResNet-18 feature extractor feeding two parallel linear classifiers.

    One head scores styles and the other scores artists; both read the same
    dropout-regularised feature vector.

    Args:
        num_styles: Output size of the style head.
        num_artists: Output size of the artist head.
        dropout_rate: Probability passed to ``nn.Dropout``.
    """

    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        feature_dim = resnet.fc.in_features
        # Replace the stock classifier with a pass-through so the backbone
        # returns the features that would have fed its final linear layer.
        resnet.fc = nn.Identity()

        self.backbone = resnet
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc_style = nn.Linear(feature_dim, num_styles)
        self.fc_artist = nn.Linear(feature_dim, num_artists)

    def forward(self, x):
        """Return ``(style_output, artist_output)`` for the input batch ``x``."""
        shared = self.dropout(self.backbone(x))
        return self.fc_style(shared), self.fc_artist(shared)

# Load dataset
# Read the annotation CSV once, then derive the subsets, the label lookups
# and the sizes of the two label spaces from it.
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
train_data, test_data = dataset.get_train_test_split()
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
num_styles, num_artists = len(label_map_style), len(label_map_artist)

# Model setup
# NOTE(review): nothing in the visible code loads fine-tuned weights into
# `model_resnet`; as written, the two classification heads keep their random
# initialisation. Confirm whether a checkpoint load is missing.
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
# `verbose` is intentionally not passed: it has been deprecated for LR
# schedulers since PyTorch 2.2 (use scheduler.get_last_lr() to inspect the
# learning rate instead), and the scheduler is never stepped in the visible
# code, so the flag had no effect here.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Load SentenceTransformer model and move it to the selected device
clip_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1')
clip_model = clip_model.to(device)

# Load GPT-Neo and its tokenizer
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# If the tokenizer defines no padding token, reuse the end-of-sequence token
# so that a pad id exists for generation.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name)
model_gptneo = model_gptneo.to(device)


def generate_description(image, max_new_tokens=250):
    """Classify an artwork and generate a free-text description of it.

    The image is run through ``model_resnet`` to pick the highest-scoring
    style and artist, a prompt is built from those labels plus the text
    returned by ``enrich_prompt``, and ``model_gptneo`` samples a
    continuation of that prompt.

    Args:
        image: Input accepted by ``data_transforms``; the caller in this file
            passes an RGB ``PIL.Image``.
        max_new_tokens: Upper bound on the number of tokens sampled after
            the prompt. Unlike ``max_length``, this budget does not shrink
            as the enriched prompt grows.

    Returns:
        A ``(predicted_style, predicted_artist, description_text)`` tuple.
        ``description_text`` is the decoded full output sequence, i.e. the
        prompt followed by the sampled continuation.
    """
    image_resnet = data_transforms(image).unsqueeze(0).to(device)

    model_resnet.eval()  # disable dropout for inference
    with torch.no_grad():
        outputs_style, outputs_artist = model_resnet(image_resnet)
        _, predicted_style_idx = torch.max(outputs_style, 1)
        _, predicted_artist_idx = torch.max(outputs_artist, 1)

    # The label maps go name -> index; invert them to recover the names.
    idx_to_style = {v: k for k, v in label_map_style.items()}
    idx_to_artist = {v: k for k, v in label_map_artist.items()}
    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    # Let the tokenizer build the attention mask. Deriving it from
    # `input_ids != pad_token_id` is unreliable because the pad token may be
    # an alias of the EOS token (see the tokenizer setup), in which case a
    # genuine EOS inside the prompt would be masked out.
    encoded = tokenizer(full_prompt, return_tensors="pt").to(device)

    output = model_gptneo.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        # `max_length` counts the prompt tokens too, so a long enriched
        # prompt could leave no room to generate (recent transformers
        # versions raise in that case); bound only the new tokens instead.
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.5,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    description_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return predicted_style, predicted_artist, description_text


# Gradio interface
def gradio_interface(image):
    """Gradio callback: analyse an uploaded image and format the result.

    Args:
        image: A filesystem path or binary file object that ``Image.open``
            can read, or ``None`` when nothing was uploaded.

    Returns:
        A text report with the predicted style, the predicted artist and the
        generated description, or a short explanatory message when no image
        was provided or the upload could not be read as an image.
    """
    if image is None:
        return "No image provided. Please upload an image."

    # Image.open handles paths and file objects alike, so no type dispatch is
    # needed. The context manager releases the underlying file as soon as the
    # RGB copy has been made.
    try:
        with Image.open(image) as opened:
            rgb_image = opened.convert("RGB")
    except OSError:
        # Covers missing files and PIL's UnidentifiedImageError (an OSError
        # subclass); report it the same way as a missing upload.
        return "The uploaded file could not be read as an image. Please upload a valid image."

    predicted_style, predicted_artist, description = generate_description(rgb_image)
    return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"

# The interface object is built at import time; the server itself is only
# started when this file is executed as a script.
iface = gr.Interface(
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description.",
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),  # hand the callback a path, not pixel data
    outputs="text",
)

if __name__ == "__main__":
    iface.launch()