# Faceshape / app.py
# Last commit e1060c2 by ruminasval: "Update app.py" (verified)
import gradio as gr
import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import mediapipe as mp
import cv2
from PIL import Image
import numpy as np
import os
import time
# -----------------------------
# 1. Face shape descriptions
# -----------------------------
# Indonesian phrase that completes the sentence "Kamu memiliki bentuk wajah ..."
# in the explanation text. Keys are the face-shape class names (the same
# names as the values of id2label).
face_shape_descriptions = dict(
    Heart="dengan dahi lebar dan dagu yang runcing.",
    Oblong="yang lebih panjang dari lebar dengan garis pipi lurus.",
    Oval="dengan proporsi seimbang dan dagu sedikit melengkung.",
    Round="dengan garis rahang melengkung dan pipi penuh.",
    Square="dengan rahang tegas dan dahi lebar.",
)
# -----------------------------
# 2. Glasses images (frames)
# -----------------------------
# Frame style name -> preview image shown in the recommendation gallery.
# The names must match the strings returned by recommend_glasses_tree(),
# because predict() looks frames up in this dict by name.
glasses_images = {
    "Oval": "glasses/oval.jpg",
    "Round": "glasses/round.jpg",
    "Square": "glasses/square.jpg",
    "Octagon": "glasses/octagon.jpg",
    "Cat Eye": "glasses/cat eye.jpg",
    "Pilot (Aviator)": "glasses/aviator.jpg",
}

# Guarantee every preview path exists: any image not shipped with the app is
# replaced by a plain gray 200x100 placeholder so the gallery still renders.
for path in glasses_images.values():
    if not os.path.exists(path):
        # exist_ok avoids the check-then-create race, and deriving the folder
        # from the path keeps this working if an entry lives elsewhere.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        dummy_image = Image.new("RGB", (200, 100), color="gray")
        dummy_image.save(path)
# -----------------------------
# 3. Label mappings
# -----------------------------
# Class index <-> face-shape name. These mappings configure the classifier
# head when the model is loaded and turn the predicted index back into a name.
id2label = dict(enumerate(("Heart", "Oblong", "Oval", "Round", "Square")))
label2id = {name: index for index, name in id2label.items()}
# -----------------------------
# 4. Load Model
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

# ignore_mismatched_sizes avoids an error when the checkpoint's classifier
# head does not have len(id2label) outputs; mismatched weights are freshly
# initialised and are expected to come from the fine-tuned file below.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Load trained weights (optional)
TRAINED_WEIGHTS_PATH = "LR-0001-adamW-32-64swin.pth"
if os.path.exists(TRAINED_WEIGHTS_PATH):
    # NOTE(review): torch.load unpickles the file; only ship trusted checkpoints.
    state_dict = torch.load(TRAINED_WEIGHTS_PATH, map_location=device)
    # strict=False ignores keys that do not match the model. That also means a
    # checkpoint saved under different key names would load nothing at all, so
    # report what was skipped instead of claiming a silent success.
    load_result = model.load_state_dict(state_dict, strict=False)
    print("✅ Trained weights loaded")
    if load_result.missing_keys:
        print(
            f"⚠️ Warning: {len(load_result.missing_keys)} model keys were not in "
            f"the checkpoint (e.g. {load_result.missing_keys[:3]})"
        )
    if load_result.unexpected_keys:
        print(
            f"⚠️ Warning: {len(load_result.unexpected_keys)} checkpoint keys were "
            f"ignored (e.g. {load_result.unexpected_keys[:3]})"
        )
else:
    print(f"⚠️ Warning: '{TRAINED_WEIGHTS_PATH}' not found, using base pretrained weights")
model.to(device)
model.eval()
# -----------------------------
# 5. Mediapipe
# -----------------------------
# Single module-level face detector reused by preprocess_image() on every
# request. Detections scoring below 0.5 are discarded; model_selection=1 is,
# per the MediaPipe docs, the full-range model (TODO confirm it suits
# close-up selfie uploads).
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    min_detection_confidence=0.5,
    model_selection=1,
)
# -----------------------------
# 6. Rule-based glasses recommendation
# -----------------------------
def recommend_glasses_tree(face_shape):
    """Return the glasses frame styles recommended for a face shape.

    Args:
        face_shape: Face-shape name, matched case-insensitively
            (e.g. "Square" or "square").

    Returns:
        A new list of frame-style names (names used as keys in
        ``glasses_images``), or an empty list for an unrecognised shape.
    """
    rules = {
        "square": ("Oval", "Round"),
        "round": ("Square", "Octagon", "Cat Eye"),
        "oval": ("Pilot (Aviator)", "Cat Eye", "Round"),
        "heart": ("Oval", "Round", "Cat Eye", "Pilot (Aviator)"),
        "oblong": ("Square", "Pilot (Aviator)", "Cat Eye"),
    }
    # Build a fresh list each call so callers may mutate the result safely.
    return list(rules.get(face_shape.lower(), ()))
# -----------------------------
# 7. Preprocess image
# -----------------------------
def preprocess_image(image):
    """Crop the first detected face and turn it into model input.

    Args:
        image: Uploaded picture, normally a PIL image (the Gradio input uses
            ``type="pil"``); an RGB ``uint8`` array also works. ``None`` is
            accepted and treated as "no face".

    Returns:
        The feature extractor's ``pixel_values`` for the 224x224 face crop,
        with the leading batch dimension squeezed away, or ``None`` when no
        usable face was found.
    """
    if image is None:
        return None
    # Grayscale, palette or RGBA uploads would otherwise produce arrays that
    # OpenCV / Mediapipe reject; normalise to three-channel RGB first.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    # Both Mediapipe and the feature extractor are fed RGB, so the array stays
    # RGB end to end (cv2.resize works per channel and is order-agnostic).
    rgb = np.array(image, dtype=np.uint8)

    results = mp_face_detection.process(rgb)
    if not results.detections:
        return None

    # Convert the relative bounding box to pixels, clamped to the image.
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = rgb.shape[:2]
    x1 = max(int(bbox.xmin * w), 0)
    y1 = max(int(bbox.ymin * h), 0)
    x2 = min(int((bbox.xmin + bbox.width) * w), w)
    y2 = min(int((bbox.ymin + bbox.height) * h), h)
    if x2 <= x1 or y2 <= y1:
        # The box lies entirely outside the image (or is empty).
        return None

    face = cv2.resize(rgb[y1:y2, x1:x2], (224, 224))
    inputs = feature_extractor(images=face, return_tensors="pt")
    return inputs["pixel_values"].squeeze(0)
# -----------------------------
# 8. Prediction function
# -----------------------------
def _elapsed_ms(start):
    """Format the time since ``start`` (a perf_counter value) as "<ms> ms"."""
    return f"{(time.perf_counter() - start) * 1000:.2f} ms"


def _load_gallery_items(frames):
    """Return (image, caption) pairs for every frame whose preview loads."""
    items = []
    for frame in frames:
        frame_image_path = glasses_images.get(frame)
        if not frame_image_path or not os.path.exists(frame_image_path):
            continue
        try:
            # Image.open is lazy: copy (which decodes) inside the with-block so
            # the file handle is closed now and decode errors are caught here
            # instead of surfacing later inside Gradio.
            with Image.open(frame_image_path) as preview:
                items.append((preview.copy(), frame))
        except Exception as e:
            print(f"Error loading image for {frame}: {e}")
    return items


def _build_explanation(label, prob, frames):
    """Compose the (Indonesian) explanation text shown to the user."""
    header = f"Bentuk wajah kamu adalah {label} ({prob:.2f}%). "
    if not frames:
        return header + "Tidak ada rekomendasi frame."
    description = face_shape_descriptions.get(label, "tidak dikenali")
    return (
        header
        + f"Kamu memiliki bentuk wajah {description} "
        + f"Rekomendasi kacamata: {', '.join(frames)}."
    )


def predict(image):
    """Classify the face shape in ``image`` and recommend glasses frames.

    Args:
        image: Upload from the Gradio ``Image`` input (a PIL image, or
            ``None`` when the button is pressed with nothing uploaded).

    Returns:
        A 4-tuple matching the UI outputs: the predicted shape name (or
        "Unknown" / "Error"), the explanation text, a list of
        ``(PIL image, frame name)`` gallery items, and the elapsed processing
        time formatted as "<ms> ms".
    """
    start = time.perf_counter()
    if image is None:
        # Without this guard the user would see a raw TypeError message.
        return (
            "Unknown",
            "Silakan unggah foto wajah terlebih dahulu.",
            [],
            _elapsed_ms(start),
        )
    try:
        inputs = preprocess_image(image)
        if inputs is None:
            return "Unknown", "Wajah tidak terdeteksi.", [], _elapsed_ms(start)
        with torch.no_grad():
            logits = model(inputs.unsqueeze(0).to(device)).logits
        probs = torch.nn.functional.softmax(logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_label = id2label[pred_idx]
        pred_prob = probs[0][pred_idx].item() * 100

        frames = recommend_glasses_tree(pred_label)
        explanation = _build_explanation(pred_label, pred_prob, frames)
        gallery_items = _load_gallery_items(frames)
        return pred_label, explanation, gallery_items, _elapsed_ms(start)
    except Exception as e:
        # UI boundary: show the message to the user, keep the traceback in logs.
        import traceback

        traceback.print_exc()
        return "Error", f"Terjadi kesalahan: {str(e)}", [], _elapsed_ms(start)
# -----------------------------
# 9. Gradio UI
# -----------------------------
def _reset_ui():
    """Clear the uploaded image and every result field (Restart button)."""
    return None, "", "", [], ""


with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Program Rekomendasi Bentuk Kacamata")
    gr.Markdown("Pastikan foto yang diunggah dapat terlihat jelas bagian wajah. Pastikan hanya menampilkan satu orang atau wajah untuk satu proses deteksi")

    with gr.Row():
        # Left column: input and controls.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Foto Wajah")
            confirm_button = gr.Button("Konfirmasi")
            restart_button = gr.Button("Restart")
        # Right column: results, in the same order as predict()'s return tuple.
        with gr.Column():
            detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
            explanation_output = gr.Textbox(label="Penjelasan")
            recommendation_gallery = gr.Gallery(
                label="Rekomendasi Kacamata",
                columns=3,
                show_label=False,
            )
            time_output = gr.Textbox(label="Inference Time (ms)", interactive=False)

    result_outputs = [
        detected_shape,
        explanation_output,
        recommendation_gallery,
        time_output,
    ]
    confirm_button.click(predict, inputs=image_input, outputs=result_outputs)
    restart_button.click(
        _reset_ui,
        inputs=None,
        outputs=[image_input, *result_outputs],
    )

    gr.Markdown("**Sumber gambar kacamata**: Katalog dari [glassdirect.co.uk](https://www.glassdirect.co.uk)")

if __name__ == "__main__":
    iface.launch()