import gradio as gr
import numpy as np
import tensorflow as tf
import logging
from PIL import Image
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
import scipy.fftpack
import time
import clip
import torch

# Set up logging: configure the root logger at INFO so the timing messages
# emitted with logging.info() in the functions below are actually printed.
logging.basicConfig(level=logging.INFO)

# Load models once at import time; every comparison request reuses these globals.
# Per the Keras applications API, include_top=False drops the ImageNet classifier
# head and pooling='avg' global-average-pools the last feature map into a vector.
resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
vgg_model = VGG16(weights='imagenet', include_top=False, pooling='avg')
# clip.load returns the model and its matching image transform; both run on the CPU.
clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")

# Preprocess function
def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
    """Load an image file and turn it into a model-ready batch of one.

    The image is loaded and resized to ``target_size``, converted to an array,
    given a leading batch axis, and finally passed through ``preprocess_func``.

    Args:
        img_path: Path to the image file.
        target_size: ``(height, width)`` to resize to; defaults to 224x224.
        preprocess_func: Model-specific ``preprocess_input`` function
            (ResNet50's by default).

    Returns:
        Whatever ``preprocess_func`` returns for the batched array.
    """
    t0 = time.time()
    pil_img = keras_image.load_img(img_path, target_size=target_size)
    batch = np.expand_dims(keras_image.img_to_array(pil_img), axis=0)
    batch = preprocess_func(batch)
    logging.info(f"Image preprocessed in {time.time() - t0:.4f} seconds")
    return batch

# Feature extraction function
def extract_features(img_path, model, preprocess_func):
    """Embed a single image with a Keras model and return a flat feature vector.

    Args:
        img_path: Path to the image file.
        model: A Keras model; in this app, the ImageNet backbones built with
            ``include_top=False, pooling='avg'``.
        preprocess_func: The ``preprocess_input`` function matching ``model``.

    Returns:
        A 1-D NumPy array holding the model output for the image.
    """
    img_array = preprocess_img(img_path, preprocess_func=preprocess_func)
    start_time = time.time()
    # Keras recommends calling the model directly for inputs that fit in one
    # batch; ``predict`` sets up a batched loop (and a progress bar) per call.
    features = np.asarray(model(img_array, training=False))
    logging.info(f"Features extracted in {time.time() - start_time:.4f} seconds")
    return features.flatten()

# Calculate cosine similarity
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two vectors.

    Args:
        vec1: First vector (array-like).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``, in ``[-1, 1]``. If either
        vector has zero norm the similarity is undefined; ``0.0`` is returned
        instead of dividing 0 by 0 (which would yield ``nan`` plus a warning).
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        return 0.0
    return np.dot(vec1, vec2) / norm_product

# pHash related functions
def phashstr(image, hash_size=8, highfreq_factor=4):
    """Compute the perceptual hash (pHash) of a PIL image as a hex string.

    The image is converted to grayscale and resized to a square of side
    ``hash_size * highfreq_factor``. A 2-D DCT is taken, and the top-left
    ``hash_size x hash_size`` low-frequency coefficients are thresholded
    against their median to form ``hash_size**2`` bits.

    Args:
        image: A PIL ``Image``.
        hash_size: Side length of the low-frequency block; the hash has
            ``hash_size**2`` bits.
        highfreq_factor: Oversampling factor for the resized image.

    Returns:
        The bits encoded by ``_binary_array_to_hex``.
    """
    side = hash_size * highfreq_factor
    gray = image.convert('L').resize((side, side), Image.Resampling.LANCZOS)
    coeffs = np.asarray(gray)
    # Separable 2-D DCT: transform along axis 0 first, then along axis 1.
    for axis in (0, 1):
        coeffs = scipy.fftpack.dct(coeffs, axis=axis)
    low_freq = coeffs[:hash_size, :hash_size]
    bits = low_freq > np.median(low_freq)
    return _binary_array_to_hex(bits.flatten())

def _binary_array_to_hex(arr):
    h = 0
    s = []
    for i, v in enumerate(arr):
        if v:
            h += 2**(i % 8)
        if (i % 8) == 7:
            s.append(hex(h)[2:].rjust(2, '0'))
            h = 0
    return ''.join(s)

def hamming_distance(hash1, hash2):
    """Return the number of differing bits between two hex-encoded hashes.

    The hashes are compared hex digit by hex digit: each pair of digits is
    XORed and the set bits are counted. The result therefore ranges from 0 to
    ``4 * len(hash1)``, i.e. it is measured in the same bit units as the
    ``len(hash1) * 4`` length that the pHash comparison normalizes by.

    Args:
        hash1: Hex string, e.g. as produced by ``_binary_array_to_hex``.
        hash2: Hex string of the same length.

    Returns:
        Bit-level Hamming distance as an ``int``.

    Raises:
        ValueError: If the hashes differ in length or contain a non-hex character.
    """
    if len(hash1) != len(hash2):
        raise ValueError("Hashes must be of the same length")
    return sum(bin(int(c1, 16) ^ int(c2, 16)).count("1") for c1, c2 in zip(hash1, hash2))

def hamming_to_similarity(distance, hash_length):
    """Convert a Hamming distance into a similarity percentage.

    Args:
        distance: Number of differing positions.
        hash_length: Total number of positions compared (must be non-zero).

    Returns:
        ``100`` for identical hashes, falling linearly to ``0`` when
        ``distance == hash_length``.
    """
    matching_fraction = 1 - distance / hash_length
    return matching_fraction * 100

# CLIP related functions
def extract_clip_features(image_path, model, preprocess):
    """Embed a single image with a CLIP model and return a flat NumPy vector.

    Args:
        image_path: Path to the image file.
        model: CLIP model as returned by ``clip.load``.
        preprocess: The matching image transform returned by ``clip.load``.

    Returns:
        1-D NumPy array holding the output of ``model.encode_image``.
    """
    # Keep the file open only while the transform reads the pixels, then
    # release the handle.
    with Image.open(image_path) as img:
        image = preprocess(img).unsqueeze(0).to("cpu")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        features = model.encode_image(image)
    return features.cpu().numpy().flatten()

# Main function
def compare_images(image1, image2, method):
    """Compare two image files with the selected method and return a score.

    Args:
        image1: File path of the first image (the Gradio input uses
            ``type="filepath"``).
        image2: File path of the second image.
        method: One of ``'pHash'``, ``'ResNet50'``, ``'VGG16'`` or ``'CLIP'``.

    Returns:
        For ``'pHash'``: the percentage (0-100) of matching hash bits from
        ``hamming_to_similarity``. For the embedding methods: the cosine
        similarity of the two feature vectors (in [-1, 1]). ``None`` if
        ``method`` is not recognised. Note the two scales differ.

    Raises:
        gr.Error: If either image is missing.
    """
    if image1 is None or image2 is None:
        # An empty upload slot arrives as None; show a readable message in the
        # UI instead of failing inside Image.open / load_img.
        raise gr.Error("Please upload both images.")

    similarity = None
    start_time = time.time()
    if method == 'pHash':
        # Context managers close both file handles once hashing is done.
        with Image.open(image1) as img1, Image.open(image2) as img2:
            hash1 = phashstr(img1)
            hash2 = phashstr(img2)
        distance = hamming_distance(hash1, hash2)
        # The hash is a hex string, so each character encodes 4 bits.
        similarity = hamming_to_similarity(distance, len(hash1) * 4)
    elif method == 'ResNet50':
        features1 = extract_features(image1, resnet_model, resnet_preprocess)
        features2 = extract_features(image2, resnet_model, resnet_preprocess)
        similarity = cosine_similarity(features1, features2)
    elif method == 'VGG16':
        features1 = extract_features(image1, vgg_model, vgg_preprocess)
        features2 = extract_features(image2, vgg_model, vgg_preprocess)
        similarity = cosine_similarity(features1, features2)
    elif method == 'CLIP':
        features1 = extract_clip_features(image1, clip_model, preprocess_clip)
        features2 = extract_clip_features(image2, clip_model, preprocess_clip)
        similarity = cosine_similarity(features1, features2)
    else:
        # No or unrecognised method: still return None, but leave a trace why.
        logging.warning(f"Unknown comparison method: {method!r}")

    logging.info(f"Image comparison using {method} completed in {time.time() - start_time:.4f} seconds")
    return similarity

# Gradio interface
demo = gr.Interface(
    fn=compare_images,
    inputs=[
        gr.Image(type="filepath", label="Upload First Image"),
        gr.Image(type="filepath", label="Upload Second Image"),
        # Preselect a method so a submission always carries one.
        gr.Radio(
            ["pHash", "ResNet50", "VGG16", "CLIP"],
            value="pHash",
            label="Select Comparison Method",
        ),
    ],
    outputs=gr.Textbox(label="Similarity"),
    title="Image Similarity Comparison",
    description="Upload two images and select the comparison method.",
    examples=[
        ["Snipaste_2024-05-31_16-18-31.jpg", "Snipaste_2024-05-31_16-18-52.jpg"],
        ["example1.png", "example2.png"],
    ],
)

# Start the server only when run as a script, so importing this module does
# not block on launch().
if __name__ == "__main__":
    demo.launch()