import glob
import os
import tempfile

import streamlit as st
import numpy as np
import cv2
from deepface import DeepFace
from scipy.spatial.distance import cosine
import matplotlib.pyplot as plt  # noqa: F401  (not referenced below; kept from the original import list)
from PIL import Image
import tensorflow as tf  # noqa: F401  (not referenced below; kept from the original import list)

# Must run before any other Streamlit call on the page.
st.set_page_config(page_title="Celebrity Lookalike Finder", layout="wide")

# Styling
# NOTE(review): the markup passed here is only whitespace, so no custom CSS is
# actually injected -- confirm whether a style block was lost.
st.markdown(""" """, unsafe_allow_html=True)

# Title
st.title("🌟 Celebrity Lookalike Finder")
st.write("Upload your photo to find your celebrity doppelganger!")


def detect_and_align_face(img_path):
    """Crop the first detected face out of an image file.

    Uses OpenCV's frontal-face Haar cascade. Despite the name, no geometric
    alignment is performed: the face box is padded by a fixed margin, clipped
    to the image bounds, and resized to 224x224.

    Args:
        img_path: Path of the image file to read with ``cv2.imread``.

    Returns:
        ``None`` if the image cannot be read; the unmodified image if no face
        is detected; otherwise the 224x224 face crop. If an exception occurs,
        the error is shown in the UI and whatever image was loaded so far
        (possibly ``None``) is returned.
    """
    # Bound up front so the except-branch can return it even when imread raises.
    img = None
    try:
        img = cv2.imread(img_path)
        if img is None:
            return None

        face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(faces) == 0:
            return img

        x, y, w, h = faces[0]
        margin = 30
        # Grow the box by `margin` on every side, clipped to the image bounds.
        y = max(0, y - margin)
        h = min(img.shape[0] - y, h + 2 * margin)
        x = max(0, x - margin)
        w = min(img.shape[1] - x, w + 2 * margin)

        face = img[y:y + h, x:x + w]
        return cv2.resize(face, (224, 224))
    except Exception as e:
        st.error(f"Error in face detection: {str(e)}")
        return img


def extract_features(img_path):
    """Compute a face embedding for an image using DeepFace (VGG-Face).

    Detection is not enforced (``enforce_detection=False``), so an embedding
    is requested even when no face is found.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A NumPy array holding the embedding, or ``None`` if the result has an
        unexpected type or DeepFace raises (the problem is reported in the UI).
    """
    try:
        embedding = DeepFace.represent(
            img_path=img_path,
            model_name="VGG-Face",
            enforce_detection=False,
            detector_backend="opencv",
        )

        # A list result is reduced to its first entry
        # (presumably one entry per detected face -- verify against DeepFace).
        if isinstance(embedding, list):
            embedding = embedding[0]

        if isinstance(embedding, dict):
            if 'embedding' in embedding:
                return np.array(embedding['embedding'])
            # Fallback: take the first array-like value found in the dict.
            for value in embedding.values():
                if isinstance(value, (list, np.ndarray)):
                    return np.array(value).flatten()

        if isinstance(embedding, np.ndarray):
            return embedding.flatten()

        st.warning(f"Unexpected embedding type: {type(embedding)}")
        return None
    except Exception as e:
        st.error(f"Error in feature extraction: {str(e)}")
        return None


@st.cache_data
def build_celebrity_database():
    """Embed every file matching ``data/*.*`` and cache the result.

    Progress is reported through a Streamlit progress bar and status line.
    Files whose embedding could not be extracted are left out.

    Returns:
        A tuple ``(celebrity_features, celebrity_paths_list)`` of two parallel
        lists: the embeddings and the paths they were computed from.
    """
    # Sorted so the database order does not depend on directory listing order.
    celebrity_paths = sorted(glob.glob('data/*.*'))
    total = len(celebrity_paths)

    celebrity_features = []
    celebrity_paths_list = []

    progress_bar = st.progress(0)
    status_text = st.empty()

    for i, img_path in enumerate(celebrity_paths):
        status_text.text(f"Processing image {i+1}/{total}")
        features = extract_features(img_path)
        if features is not None:
            celebrity_features.append(features)
            celebrity_paths_list.append(img_path)
        progress_bar.progress((i + 1) / total)

    status_text.text("Database built successfully!")
    return celebrity_features, celebrity_paths_list


def find_matches(user_features, celebrity_features, celebrity_paths, top_n=5):
    """Display the celebrities most similar to the user's embedding.

    Similarity is ``1 - cosine distance``. Each score is kept together with
    the path it belongs to, so skipping an incompatible embedding can never
    shift the mapping between scores and images.

    Args:
        user_features: Embedding of the uploaded photo (NumPy array).
        celebrity_features: Celebrity embeddings, parallel to
            ``celebrity_paths``.
        celebrity_paths: Image paths, parallel to ``celebrity_features``.
        top_n: Number of columns to lay out and maximum matches to show.

    Returns:
        ``None``. Results (or a warning when nothing is comparable) are
        rendered directly to the Streamlit page.
    """
    scored = []
    for celeb_feature, celeb_path in zip(celebrity_features, celebrity_paths):
        if user_features.shape != celeb_feature.shape:
            continue
        similarity = 1 - cosine(user_features, celeb_feature)
        # A non-finite score (e.g. from an all-zero vector) cannot be ranked.
        if not np.isfinite(similarity):
            continue
        scored.append((similarity, celeb_path))

    if not scored:
        st.warning("No valid comparisons could be made")
        return

    scored.sort(key=lambda pair: pair[0], reverse=True)
    best = scored[:top_n]

    # Display results in columns
    cols = st.columns(top_n)
    for rank, ((similarity, celeb_path), col) in enumerate(zip(best, cols), start=1):
        with col:
            # Context manager releases the file handle once the image is rendered.
            with Image.open(celeb_path) as celeb_img:
                st.image(celeb_img, caption=f"Match {rank}\nSimilarity: {similarity:.2%}")


def main():
    """Run the page: build the database, accept an upload, show matches."""
    # Load celebrity database
    with st.spinner("Building celebrity database..."):
        celebrity_features, celebrity_paths = build_celebrity_database()

    # File uploader
    uploaded_file = st.file_uploader("Choose a photo", type=['jpg', 'jpeg', 'png'])
    if uploaded_file is None:
        return

    # Create columns for side-by-side display
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Your Photo")
        st.image(uploaded_file)

    # extract_features takes a file path, so the upload is spooled to a temp
    # file; try/finally guarantees it is removed even if processing raises.
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(uploaded_file.getvalue())

        # Extract features and find matches
        with st.spinner("Finding your celebrity matches..."):
            user_features = extract_features(tmp_path)

            if user_features is not None:
                with col2:
                    st.subheader("Your Celebrity Matches")
                    find_matches(user_features, celebrity_features, celebrity_paths)
            else:
                st.error("Could not process the uploaded image. Please try another photo.")
    finally:
        # Clean up
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)


if __name__ == "__main__":
    main()