import streamlit as st
import cv2
import numpy as np
from datetime import datetime
import torch
from facenet_pytorch import MTCNN, InceptionResnetV1
from keras.models import load_model
from PIL import Image
import sqlite3
import os
import tempfile

DB_NAME = "emotion_detection.db"
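
# Each recognized face is logged to SQLite as a (name, emotion, timestamp)
# row; the newest rows feed the "Recent Records" list at the bottom of the page.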
def initialize_database():
    conn = sqlite3.connect(DB_NAME)
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS face_data (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            emotion TEXT NOT NULL,
            timestamp TEXT NOT NULL
        )
    """)
    conn.commit()
    conn.close()


initialize_database()

st.markdown("<h1 style='text-align: center;'>Emotion Detection with Face Recognition</h1>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center;'>angry, fear, happy, neutral, sad, surprise</h3>", unsafe_allow_html=True)
@st.cache_resource
def load_emotion_model():
    model = load_model('CNN_Model_acc_75.h5')
    return model


emotion_model = load_emotion_model()
emotion_labels = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']
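
# FaceNet embedder (InceptionResnetV1 pretrained on VGGFace2) maps each face
# crop to a 512-dimensional embedding; MTCNN provides the face detector.
# Both come from facenet_pytorch.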
facenet = InceptionResnetV1(pretrained='vggface2').eval()
mtcnn = MTCNN()
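
# Recognition gallery: one FaceNet embedding per reference image in the
# known_faces/ folder, with the person's name taken from the filename.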
known_faces = []
known_names = []


def load_known_faces():
    folder_path = "known_faces"
    if not os.path.isdir(folder_path):
        return  # No reference images yet; every detection will be "Unknown".
    for image_name in os.listdir(folder_path):
        if image_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            image_path = os.path.join(folder_path, image_name)
            image = Image.open(image_path).convert("RGB")
            boxes, _ = mtcnn.detect(image)

            if boxes is not None:
                face_box = boxes[0].astype(int)
                cropped_face = image.crop((face_box[0], face_box[1], face_box[2], face_box[3]))
                cropped_face = cropped_face.resize((160, 160))
                face_tensor = np.array(cropped_face).transpose(2, 0, 1) / 255.0
                face_tensor = torch.tensor(face_tensor, dtype=torch.float32).unsqueeze(0)

                with torch.no_grad():
                    embedding = facenet(face_tensor).numpy()

                known_faces.append(embedding)
                known_names.append(os.path.splitext(image_name)[0])


load_known_faces()
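
# Nearest-neighbor matching by L2 distance over the gallery embeddings. The
# 0.6 cutoff is this app's tuning choice; a suitable value depends on how the
# crops are scaled and normalized before embedding.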
def recognize_face(embedding):
    min_distance = float('inf')
    name = "Unknown"
    for idx, known_embedding in enumerate(known_faces):
        distance = np.linalg.norm(known_embedding - embedding)
        if distance < min_distance and distance < 0.6:
            min_distance = distance
            name = known_names[idx]
    return name
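
# process_frame expects a BGR frame (OpenCV's convention): it detects faces,
# classifies each crop's emotion with the CNN, identifies the person via
# FaceNet, logs recognized detections to SQLite, and annotates the frame.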
def process_frame(frame):
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    faces, _ = mtcnn.detect(frame_rgb)
    result_text = ""

    if faces is not None:
        for face_box in faces:
            x1, y1, x2, y2 = map(int, face_box)
            # Clamp the box to the frame so slightly out-of-bounds detections
            # do not produce empty crops.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, frame.shape[1]), min(y2, frame.shape[0])
            cropped_face = frame_rgb[y1:y2, x1:x2]
            if cropped_face.size == 0:
                continue

            # Emotion classification on a 48x48 crop.
            resized_face = cv2.resize(cropped_face, (48, 48))
            face_normalized = resized_face / 255.0
            face_array = np.expand_dims(face_normalized, axis=0)
            predictions = emotion_model.predict(face_array)
            emotion = emotion_labels[np.argmax(predictions[0])]

            # Face recognition on a 160x160 crop.
            cropped_face_for_recognition = cv2.resize(cropped_face, (160, 160))
            face_tensor = np.array(cropped_face_for_recognition).transpose(2, 0, 1) / 255.0
            face_tensor = torch.tensor(face_tensor, dtype=torch.float32).unsqueeze(0)

            with torch.no_grad():
                face_embedding = facenet(face_tensor).numpy()

            name = recognize_face(face_embedding)

            # Only recognized faces are logged to the database.
            if name != "Unknown":
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                conn = sqlite3.connect(DB_NAME)
                cursor = conn.cursor()
                cursor.execute(
                    "INSERT INTO face_data (name, emotion, timestamp) VALUES (?, ?, ?)",
                    (name, emotion, timestamp),
                )
                conn.commit()
                conn.close()

            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            result_text = f"{name} is feeling {emotion}"
            cv2.putText(frame, result_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    else:
        result_text = "No face detected!"

    return frame, result_text
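
# Stream frames from an OpenCV capture into the page, updating a single image
# and caption placeholder until the source runs out.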
def video_feed(video_source):
    frame_placeholder = st.empty()
    text_placeholder = st.empty()

    while True:
        ret, frame = video_source.read()
        if not ret:
            break
        frame, result_text = process_frame(frame)
        frame_placeholder.image(frame, channels="BGR", use_column_width=True)
        text_placeholder.markdown(f"<h3 style='text-align: center;'>{result_text}</h3>", unsafe_allow_html=True)
    video_source.release()
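
# Input selection: still images (camera snapshot or upload) go through
# process_frame once; uploaded videos are played frame by frame via video_feed.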
upload_choice = st.sidebar.radio("Choose Input Source", ["Upload Image", "Upload Video", "Camera"])

if upload_choice == "Camera":
    image = st.camera_input("Take a picture")
    if image:
        # The camera snapshot is RGB; convert to BGR because process_frame
        # follows OpenCV's BGR convention.
        frame = cv2.cvtColor(np.array(Image.open(image).convert("RGB")), cv2.COLOR_RGB2BGR)
        frame, result_text = process_frame(frame)
        st.image(frame, channels="BGR", caption='Processed Image', use_column_width=True)
        st.markdown(f"<h3 style='text-align: center;'>{result_text}</h3>", unsafe_allow_html=True)

elif upload_choice == "Upload Image":
    uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if uploaded_image:
        # Same RGB-to-BGR conversion as the camera path.
        image = Image.open(uploaded_image).convert("RGB")
        frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        frame, result_text = process_frame(frame)
        st.image(frame, channels="BGR", caption='Processed Image', use_column_width=True)
        st.markdown(f"<h3 style='text-align: center;'>{result_text}</h3>", unsafe_allow_html=True)

elif upload_choice == "Upload Video":
    uploaded_video = st.file_uploader("Upload Video", type=["mp4", "mov", "avi"])
    if uploaded_video:
        # Write the upload to a named temp file so OpenCV can open it by path,
        # keeping the original extension to help container detection.
        suffix = os.path.splitext(uploaded_video.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(uploaded_video.read())
        video_source = cv2.VideoCapture(tfile.name)
        video_feed(video_source)
        os.remove(tfile.name)
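
# Show the five most recent detections logged to SQLite.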
st.markdown("### Recent Records")
conn = sqlite3.connect(DB_NAME)
cursor = conn.cursor()
cursor.execute("SELECT name, emotion, timestamp FROM face_data ORDER BY timestamp DESC LIMIT 5")
records = cursor.fetchall()
conn.close()

for record in records:
    col1, col2, col3 = st.columns(3)
    col1.write(f"**Name**: {record[0]}")
    col2.write(f"**Emotion**: {record[1]}")
    col3.write(f"**Timestamp**: {record[2]}")