Spaces:
Running
Running
File size: 7,678 Bytes
ed33d5b 425e4d6 82b5819 32d14bc ec300a7 82b5819 ed33d5b 82b5819 ed33d5b ec300a7 82b5819 ed33d5b 82b5819 ec300a7 ed33d5b ec300a7 82b5819 ed33d5b ec300a7 ed33d5b 82b5819 ed33d5b 82b5819 ed33d5b ec300a7 82b5819 ed33d5b ec300a7 ed33d5b ec300a7 82b5819 ed33d5b 82b5819 ed33d5b 82b5819 e1060c2 82b5819 9ea4e99 82b5819 9ea4e99 ed33d5b 82b5819 ec300a7 82b5819 ec300a7 82b5819 ec300a7 82b5819 ec300a7 82b5819 ec300a7 425e4d6 82b5819 32d14bc 1b31555 82b5819 ec300a7 82b5819 ec300a7 82b5819 ec300a7 ed33d5b ec300a7 32d14bc ed33d5b ec300a7 cbc8a5e 1b31555 3f264c3 ec300a7 82b5819 e925709 ec300a7 3f264c3 82b5819 ec300a7 82b5819 93bb675 9374fca 82b5819 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
import gradio as gr
import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import mediapipe as mp
import cv2
from PIL import Image
import numpy as np
import os
import time
# -----------------------------
# 1. Face shape descriptions
# -----------------------------
# Indonesian sentence fragments, one per face-shape label.  predict() inserts
# the fragment right after "Kamu memiliki bentuk wajah ", so every entry is
# written as the continuation of that sentence and ends with its own period.
face_shape_descriptions = dict(
    Heart="dengan dahi lebar dan dagu yang runcing.",
    Oblong="yang lebih panjang dari lebar dengan garis pipi lurus.",
    Oval="dengan proporsi seimbang dan dagu sedikit melengkung.",
    Round="dengan garis rahang melengkung dan pipi penuh.",
    Square="dengan rahang tegas dan dahi lebar.",
)
# -----------------------------
# 2. Glasses images (frames)
# -----------------------------
# Frame style name -> path of the picture shown for it in the gallery.
# All pictures live in the "glasses/" directory; the keys are the same names
# that recommend_glasses_tree() returns.
glasses_images = {
    frame_name: "glasses/" + file_name
    for frame_name, file_name in (
        ("Oval", "oval.jpg"),
        ("Round", "round.jpg"),
        ("Square", "square.jpg"),
        ("Octagon", "octagon.jpg"),
        ("Cat Eye", "cat eye.jpg"),
        ("Pilot (Aviator)", "aviator.jpg"),
    )
}
# Let the app start even when the frame pictures are missing: make sure the
# picture directory exists and write a gray 200x100 placeholder for every
# picture file that is absent.  Files that already exist are left untouched.
# exist_ok=True replaces the separate existence check, so a directory created
# between "check" and "create" no longer raises.
os.makedirs("glasses", exist_ok=True)
for path in glasses_images.values():
    if not os.path.exists(path):
        Image.new('RGB', (200, 100), color='gray').save(path)
# -----------------------------
# 3. Label mappings
# -----------------------------
# Class index <-> face-shape name.  Both tables are handed to the classifier
# when it is loaded, and predict() uses id2label to name the arg-max class.
id2label = dict(enumerate(('Heart', 'Oblong', 'Oval', 'Round', 'Square')))
label2id = {name: index for index, name in id2label.items()}
# -----------------------------
# 4. Load Model
# -----------------------------
# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
# Optional fine-tuned weights, looked up relative to the working directory.
trained_weights_path = 'LR-0001-adamW-32-64swin.pth'

feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
# NOTE(review): ignore_mismatched_sizes=True is presumably needed because the
# checkpoint's classification head does not have five outputs - confirm.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Load trained weights (optional)
if os.path.exists(trained_weights_path):
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be placed here (consider weights_only=True once
    # the checkpoint is confirmed to be a plain state dict).
    state_dict = torch.load(trained_weights_path, map_location=device)
    # strict=False accepts checkpoints whose keys do not fully match the
    # model; report the mismatches so a partial load does not go unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"⚠️ Warning: keys missing from checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"⚠️ Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
    print("✅ Trained weights loaded")
else:
    print(f"⚠️ Warning: '{trained_weights_path}' not found, using base pretrained weights")

model.to(device)
# Inference only: switch the model to evaluation mode.
model.eval()
# -----------------------------
# 5. Mediapipe
# -----------------------------
# Single shared MediaPipe face detector; preprocess_image() calls its
# process() method to locate the face before cropping.
# NOTE(review): model_selection=1 is presumably MediaPipe's full-range model
# and 0.5 the minimum score for a detection to be reported - confirm against
# the MediaPipe documentation for the installed version.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    min_detection_confidence=0.5,
    model_selection=1,
)
# -----------------------------
# 6. Rule-based glasses recommendation
# -----------------------------
# Lower-cased face shape -> recommended frame styles, in display order.
_FRAMES_BY_FACE_SHAPE = {
    "square": ("Oval", "Round"),
    "round": ("Square", "Octagon", "Cat Eye"),
    "oval": ("Pilot (Aviator)", "Cat Eye", "Round"),
    "heart": ("Oval", "Round", "Cat Eye", "Pilot (Aviator)"),
    "oblong": ("Square", "Pilot (Aviator)", "Cat Eye"),
}


def recommend_glasses_tree(face_shape):
    """Return the frame styles recommended for a face shape.

    Args:
        face_shape: Face-shape name.  Matching ignores letter case and
            surrounding whitespace.

    Returns:
        A new list of frame-style names in display order.  The list is empty
        for an unknown shape and for non-string input such as ``None``.
    """
    if not isinstance(face_shape, str):
        return []
    shape_key = face_shape.strip().lower()
    # Copy into a fresh list so callers can mutate the result freely.
    return list(_FRAMES_BY_FACE_SHAPE.get(shape_key, ()))
# -----------------------------
# 7. Preprocess image
# -----------------------------
def preprocess_image(image):
    """Crop the first detected face and convert it into a model input tensor.

    The image stays in RGB channel order throughout: the detector is given
    the RGB array directly and the crop is resized in place, so no colour
    conversions are needed.

    Args:
        image: A PIL image, or an array-like of RGB pixels.  ``None`` is
            accepted and treated the same as "no face found".

    Returns:
        The ``pixel_values`` entry produced by ``feature_extractor`` with
        dimension 0 squeezed, or ``None`` when no image was supplied, no face
        was detected, or the detected box is empty after clamping it to the
        image bounds.
    """
    if image is None:
        return None

    # PIL images in other modes (grayscale, palette, RGBA, ...) are
    # normalised to three RGB channels; plain arrays are used as given.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    rgb = np.array(image, dtype=np.uint8)

    results = mp_face_detection.process(rgb)
    if not results.detections:
        return None

    # Only the first detection is used.  Its box is expressed relative to the
    # image size, so scale it to pixels and clamp it to the image bounds.
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = rgb.shape[:2]
    x1 = max(int(bbox.xmin * w), 0)
    y1 = max(int(bbox.ymin * h), 0)
    x2 = min(int((bbox.xmin + bbox.width) * w), w)
    y2 = min(int((bbox.ymin + bbox.height) * h), h)
    if x2 <= x1 or y2 <= y1:
        return None

    face = cv2.resize(rgb[y1:y2, x1:x2], (224, 224))
    inputs = feature_extractor(images=face, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)
# -----------------------------
# 8. Prediction function
# -----------------------------
def _elapsed_ms_text(start):
    """Format the time elapsed since *start* (a perf_counter reading) as 'N.NN ms'."""
    return f"{(time.perf_counter() - start) * 1000:.2f} ms"


def _load_frame_gallery(frame_names):
    """Return (image, caption) pairs for the frames whose picture can be loaded."""
    items = []
    for frame in frame_names:
        frame_image_path = glasses_images.get(frame)
        if not frame_image_path or not os.path.exists(frame_image_path):
            continue
        try:
            # copy() reads the pixels into memory, so the file handle can be
            # closed immediately instead of staying open with the lazy image.
            with Image.open(frame_image_path) as source:
                items.append((source.copy(), frame))
        except Exception as e:
            # Best effort: one unreadable picture must not fail the prediction.
            print(f"Error loading image for {frame}: {e}")
    return items


def predict(image):
    """Classify the face shape in *image* and recommend glasses frames.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(label, explanation, gallery_items, elapsed)``:
        ``label`` is the predicted face shape, ``"Unknown"`` when there is no
        image or no detectable face, or ``"Error"`` when an exception
        occurred; ``explanation`` is the Indonesian text shown to the user;
        ``gallery_items`` is a list of ``(image, caption)`` pairs for the
        recommended frames; ``elapsed`` is the wall-clock time spent in this
        call formatted as ``"N.NN ms"``.  Exceptions are never propagated.
    """
    start = time.perf_counter()

    # Nothing uploaded: answer with a clear message instead of letting the
    # image conversion fail and surface a raw exception text.
    if image is None:
        return "Unknown", "Silakan unggah foto terlebih dahulu.", [], _elapsed_ms_text(start)

    try:
        inputs = preprocess_image(image)
        if inputs is None:
            return "Unknown", "Wajah tidak terdeteksi.", [], _elapsed_ms_text(start)

        # Add the batch dimension back and move the tensor to the model's device.
        inputs = inputs.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_label = id2label[pred_idx]
        pred_prob = probs[0][pred_idx].item() * 100

        frame_recommendations = recommend_glasses_tree(pred_label)
        description = face_shape_descriptions.get(pred_label, "tidak dikenali")
        gallery_items = _load_frame_gallery(frame_recommendations)

        if frame_recommendations:
            recommended_frames_text = ', '.join(frame_recommendations)
            explanation = (
                f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                f"Kamu memiliki bentuk wajah {description} "
                f"Rekomendasi kacamata: {recommended_frames_text}."
            )
        else:
            explanation = (
                f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                f"Tidak ada rekomendasi frame."
            )
        return pred_label, explanation, gallery_items, _elapsed_ms_text(start)
    except Exception as e:
        # UI boundary: report the failure to the user, but also print it so
        # the cause is visible in the server output.
        print(f"Prediction failed: {e!r}")
        return "Error", f"Terjadi kesalahan: {str(e)}", [], _elapsed_ms_text(start)
# -----------------------------
# 9. Gradio UI
# -----------------------------
# NOTE(review): the nesting below (which widgets sit in which column, and the
# handlers/footnote placed directly under the Blocks context) is reconstructed
# from the order of the widgets - confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Program Rekomendasi Bentuk Kacamata")
    gr.Markdown("Pastikan foto yang diunggah dapat terlihat jelas bagian wajah. Pastikan hanya menampilkan satu orang atau wajah untuk satu proses deteksi")

    with gr.Row():
        # Input side: the photo and the two action buttons.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Foto Wajah")
            confirm_button = gr.Button("Konfirmasi")
            restart_button = gr.Button("Restart")

        # Output side: one widget per element of the tuple predict() returns.
        with gr.Column():
            detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
            explanation_output = gr.Textbox(label="Penjelasan")
            recommendation_gallery = gr.Gallery(
                label="Rekomendasi Kacamata", columns=3, show_label=False
            )
            time_output = gr.Textbox(label="Inference Time (ms)", interactive=False)

    # Same order as predict()'s return value.
    result_outputs = [detected_shape, explanation_output, recommendation_gallery, time_output]

    confirm_button.click(fn=predict, inputs=image_input, outputs=result_outputs)

    # Restart clears the uploaded photo and every result widget.
    restart_button.click(
        fn=lambda: (None, "", "", [], ""),
        inputs=None,
        outputs=[image_input, *result_outputs],
    )

    gr.Markdown("**Sumber gambar kacamata**: Katalog dari [glassdirect.co.uk](https://www.glassdirect.co.uk)")

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|