import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing import image

# Load the pre-trained MobileNetV2 model for feature extraction.
# include_top=False drops the ImageNet classifier head; pooling='avg'
# global-average-pools the final feature map, so model.predict returns a
# single feature vector per image (shape (1, 1280)) suitable for
# cosine-similarity comparison. Loaded once at import time so both
# predictions in compare_images reuse the same weights.
model = MobileNetV2(weights='imagenet', include_top=False, pooling='avg')

def preprocess_image(img):
    """Resize and normalize an image for MobileNetV2.

    Accepts either a PIL image or a NumPy array (Gradio's ``gr.Image``
    component delivers a NumPy array unless ``type="pil"`` is requested,
    so supporting both avoids an AttributeError on the default setup).

    Args:
        img: PIL.Image.Image or HxWxC numpy.ndarray (RGB).

    Returns:
        numpy.ndarray of shape (1, 224, 224, 3), preprocessed with the
        MobileNetV2 [-1, 1] scaling.
    """
    if isinstance(img, np.ndarray):
        # NumPy path: resize with TensorFlow (returns float32 tensor).
        img_array = tf.image.resize(img, (224, 224)).numpy()
    else:
        # PIL path: resize first, then convert to an array.
        img = img.resize((224, 224))
        img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    return preprocess_input(img_array)  # Scale pixels to [-1, 1]

def compare_images(image1, image2):
    """Compare two images via cosine similarity of MobileNetV2 features.

    Args:
        image1: First image (PIL image or NumPy array).
        image2: Second image (PIL image or NumPy array).

    Returns:
        dict: {'Similarity': float} with the cosine similarity in [-1, 1]
        (0.0 if either feature vector has zero norm).
    """
    # Extract features and flatten the (1, 1280) predictions to 1-D vectors
    # so np.dot yields a true scalar. float() on a (1, 1) array is
    # deprecated in modern NumPy.
    features1 = model.predict(preprocess_image(image1)).ravel()
    features2 = model.predict(preprocess_image(image2)).ravel()

    # Guard against zero-norm vectors to avoid a NaN from division by zero.
    denom = np.linalg.norm(features1) * np.linalg.norm(features2)
    if denom == 0.0:
        return {'Similarity': 0.0}

    similarity = np.dot(features1, features2) / denom
    return {'Similarity': float(similarity)}

# Setup Gradio interface.
# type="pil" makes Gradio hand the callback PIL images, matching the
# PIL-style resize done in preprocess_image (the default is a NumPy array).
iface = gr.Interface(
    fn=compare_images,
    inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    outputs=gr.Label(),
    title="Image Similarity Checker",
    description="Upload two images to compare their similarity based on extracted features using MobileNetV2."
)

# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()