ruminasval committed on
Commit 82b5819 · verified · 1 Parent(s): 425e4d6

Update app.py

Files changed (1)
  1. app.py +120 -189
app.py CHANGED
@@ -1,24 +1,23 @@
- import os
- import cv2
  import gradio as gr
  import torch
- import numpy as np
- import mediapipe as mp
- from PIL import Image, ImageOps
  from transformers import SwinForImageClassification, AutoFeatureExtractor

- # =========================
- # Configuration & Metadata
- # =========================
- FACE_SHAPE_DESCRIPTIONS = {
-     "Heart": "dengan dahi lebar dan dagu yang runcing.",
      "Oblong": "yang lebih panjang dari lebar dengan garis pipi lurus.",
-     "Oval": "dengan proporsi seimbang dan dagu sedikit melengkung.",
-     "Round": "dengan garis rahang melengkung dan pipi penuh.",
      "Square": "dengan rahang tegas dan dahi lebar."
  }

- GLASSES_IMAGES = {
      "Oval": "glasses/oval.jpg",
      "Round": "glasses/round.jpg",
      "Square": "glasses/square.jpg",
@@ -27,209 +26,141 @@ GLASSES_IMAGES = {
      "Pilot (Aviator)": "glasses/aviator.jpg"
  }

- # Make sure the 'glasses' folder exists & create dummy images if any are missing
- os.makedirs("glasses", exist_ok=True)
- for _, p in GLASSES_IMAGES.items():
-     if not os.path.exists(p):
-         Image.new("RGB", (300, 160), color="gray").save(p)

- ID2LABEL = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
- LABEL2ID = {v: k for k, v in ID2LABEL.items()}

- # =========================
- # Model & Device
- # =========================
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model_ckpt = "microsoft/swin-tiny-patch4-window7-224"

- feature_extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
  model = SwinForImageClassification.from_pretrained(
-     model_ckpt,
-     label2id=LABEL2ID,
-     id2label=ID2LABEL,
      ignore_mismatched_sizes=True
  )

- # Load trained weights if available
- if os.path.exists("LR-0001-adamW-32-64swin.pth"):
-     state_dict = torch.load("LR-0001-adamW-32-64swin.pth", map_location=device)
      model.load_state_dict(state_dict, strict=False)

- model.to(device)
- model.eval()
- # =========================
- # Utils
- # =========================
- def recommend_glasses_tree(face_shape: str):
-     s = face_shape.strip().lower()
-     if s == "square":
          return ["Oval", "Round"]
-     if s == "round":
          return ["Square", "Octagon", "Cat Eye"]
-     if s == "oval":
          return ["Oval", "Pilot (Aviator)", "Cat Eye", "Round"]
-     if s == "heart":
          return ["Oval", "Round", "Cat Eye", "Pilot (Aviator)"]
-     if s == "oblong":
          return ["Square", "Pilot (Aviator)", "Cat Eye"]
-     return []
-
- def _pil_to_bgr_ndarray(img: Image.Image) -> np.ndarray:
-     """Ensure the image is RGB, strip alpha/EXIF, then convert to BGR (OpenCV)."""
-     if img.mode not in ("RGB", "L"):
-         # handle RGBA/CMYK/LA, etc.
-         img = img.convert("RGB")
-     elif img.mode == "L":
-         img = ImageOps.colorize(img, black="black", white="white").convert("RGB")
-
-     # strip EXIF for safety
-     img_no_exif = Image.new(img.mode, img.size)
-     img_no_exif.putdata(list(img.getdata()))
-
-     arr = np.array(img_no_exif, dtype=np.uint8)  # RGB
-     return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
-
- def _safe_crop(img: np.ndarray, x1, y1, x2, y2):
-     h, w = img.shape[:2]
-     x1 = max(0, min(w-1, x1))
-     y1 = max(0, min(h-1, y1))
-     x2 = max(x1+1, min(w, x2))
-     y2 = max(y1+1, min(h, y2))
-     return img[y1:y2, x1:x2]
-
- def _center_crop_square(img: np.ndarray) -> np.ndarray:
-     h, w = img.shape[:2]
-     side = min(h, w)
-     x1 = (w - side) // 2
-     y1 = (h - side) // 2
-     return img[y1:y1+side, x1:x1+side]
-
- def preprocess_image(pil_image: Image.Image) -> torch.Tensor:
-     """Detect the face (mediapipe); fall back to a center crop on failure. Resize to 224 and return the pixel_values tensor."""
-     bgr = _pil_to_bgr_ndarray(pil_image)
-
-     # mediapipe face detection (create the object per call to stay thread-safe)
-     with mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as fd:
-         results = fd.process(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
-
-     if results.detections:
-         det = results.detections[0]
-         bbox = det.location_data.relative_bounding_box
-         h, w = bgr.shape[:2]
-         x1 = int(bbox.xmin * w)
-         y1 = int(bbox.ymin * h)
-         x2 = int((bbox.xmin + bbox.width) * w)
-         y2 = int((bbox.ymin + bbox.height) * h)
-
-         face = _safe_crop(bgr, x1, y1, x2, y2)
-         if face.size == 0 or face.shape[0] < 32 or face.shape[1] < 32:
-             # fall back if the box looks wrong
-             face = _center_crop_square(bgr)
      else:
-         # fall back if no face was detected
-         face = _center_crop_square(bgr)
-
-     face = cv2.resize(face, (224, 224), interpolation=cv2.INTER_AREA)
-     face_rgb = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
-     inputs = feature_extractor(images=face_rgb, return_tensors="pt")
-     return inputs["pixel_values"].squeeze(0)
-
- def format_gallery_items(frames):
-     items = []
-     for frame in frames:
-         path = GLASSES_IMAGES.get(frame)
-         if path and os.path.exists(path):
-             try:
-                 img = Image.open(path).convert("RGB")
-                 items.append((img, frame))
-             except Exception as e:
-                 print(f"[WARN] failed to open frame image {frame}: {e}")
-     return items
-
- # =========================
- # Prediction
- # =========================
- @torch.inference_mode()
- def predict(image: Image.Image):
-     """
-     Return:
-     - detected face shape: str
-     - explanation: str
-     - glasses recommendations: list[(PIL.Image, caption)]
-     """
      try:
-         if image is None:
-             return "Unknown", "Tidak ada gambar yang diunggah.", []
-
-         pixel_values = preprocess_image(image).unsqueeze(0).to(device)
-
-         # AMP for GPU
-         if device.type == "cuda":
-             with torch.cuda.amp.autocast():
-                 outputs = model(pixel_values)
-         else:
-             outputs = model(pixel_values)
-
-         probs = torch.softmax(outputs.logits, dim=1)[0]
-         pred_idx = int(torch.argmax(probs).item())
-         pred_label = ID2LABEL[pred_idx]
-         conf = float(probs[pred_idx].item()) * 100.0
-
-         recs = recommend_glasses_tree(pred_label)
-         gallery = format_gallery_items(recs)
-
-         desc = FACE_SHAPE_DESCRIPTIONS.get(pred_label, "tidak dikenali")
-         if recs:
-             rec_text = ", ".join(recs)
-             explanation = (
-                 f"Bentuk wajah kamu adalah {pred_label} ({conf:.2f}%). "
-                 f"Kamu memiliki bentuk wajah {desc} "
-                 f"Rekomendasi bentuk kacamata yang sesuai: {rec_text}."
-             )
-         else:
-             explanation = (
-                 f"Bentuk wajah kamu adalah {pred_label} ({conf:.2f}%). "
-                 f"Belum ada rekomendasi frame untuk bentuk wajah ini."
-             )
-
-         return pred_label, explanation, gallery

      except Exception as e:
-         # Never crash: always return a tuple matching the output schema
-         return "Unknown", f"Terjadi kesalahan saat memproses gambar: {str(e)}", []

- # =========================
- # Gradio UI
- # =========================
  with gr.Blocks(theme=gr.themes.Soft()) as iface:
-     gr.Markdown("# Program Rekomendasi Kacamata Berdasarkan Bentuk Wajah")
      gr.Markdown("Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai.")

      with gr.Row():
          with gr.Column():
-             # important: type='pil' so the API sees an argument named 'image'
-             image_input = gr.Image(type="pil", label="Gambar Wajah")
-             confirm_button = gr.Button("Konfirmasi", variant="primary")
              restart_button = gr.Button("Restart")
          with gr.Column():
-             detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi", interactive=False)
-             explanation_output = gr.Textbox(label="Penjelasan", lines=4, interactive=False)
-             recommendation_gallery = gr.Gallery(label="Rekomendasi Kacamata", columns=3, show_label=True)
-
-     confirm_button.click(
-         predict,
-         inputs=[image_input],
-         outputs=[detected_shape, explanation_output, recommendation_gallery]
-     )
-     restart_button.click(
-         lambda: (None, "", []),
-         inputs=None,
-         outputs=[image_input, detected_shape, explanation_output, recommendation_gallery]
-     )
-
-     gr.Markdown("**Sumber gambar kacamata**: Katalog dari glassdirect.co.uk")

  if __name__ == "__main__":
-     # show_error=True so server-side error messages are clearly visible when debugging
-     iface.queue().launch(show_error=True, server_name="0.0.0.0", server_port=7860)

  import gradio as gr
  import torch
  from transformers import SwinForImageClassification, AutoFeatureExtractor
+ import mediapipe as mp
+ import cv2
+ from PIL import Image
+ import numpy as np
+ import os

+ # --- Face shape descriptions
+ face_shape_descriptions = {
+     "Heart": "dengan dahi lebar dan dagu yang runcing.",
      "Oblong": "yang lebih panjang dari lebar dengan garis pipi lurus.",
+     "Oval": "dengan proporsi seimbang dan dagu sedikit melengkung.",
+     "Round": "dengan garis rahang melengkung dan pipi penuh.",
      "Square": "dengan rahang tegas dan dahi lebar."
  }

+ # --- Glasses frame images
+ glasses_images = {
      "Oval": "glasses/oval.jpg",
      "Round": "glasses/round.jpg",
      "Square": "glasses/square.jpg",
      "Pilot (Aviator)": "glasses/aviator.jpg"
  }

+ # Ensure folder exists
+ if not os.path.exists("glasses"):
+     os.makedirs("glasses")
+ for _, path in glasses_images.items():
+     if not os.path.exists(path):
+         Image.new('RGB', (200, 100), color='gray').save(path)

+ id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
+ label2id = {v: k for k, v in id2label.items()}

+ # --- Load model
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

  model = SwinForImageClassification.from_pretrained(
+     model_checkpoint,
+     label2id=label2id,
+     id2label=id2label,
      ignore_mismatched_sizes=True
  )

+ # Load trained weights if available
+ if os.path.exists('LR-0001-adamW-32-64swin.pth'):
+     state_dict = torch.load('LR-0001-adamW-32-64swin.pth', map_location=device)
      model.load_state_dict(state_dict, strict=False)
+ model.to(device).eval()

+ # --- Mediapipe
+ mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
+ # --- Decision tree rules
+ def recommend_glasses_tree(face_shape):
+     face_shape = face_shape.lower()
+     if face_shape == "square":
          return ["Oval", "Round"]
+     elif face_shape == "round":
          return ["Square", "Octagon", "Cat Eye"]
+     elif face_shape == "oval":
          return ["Oval", "Pilot (Aviator)", "Cat Eye", "Round"]
+     elif face_shape == "heart":
          return ["Oval", "Round", "Cat Eye", "Pilot (Aviator)"]
+     elif face_shape == "oblong":
          return ["Square", "Pilot (Aviator)", "Cat Eye"]
      else:
+         return []
+
+ # --- Preprocess image
+ def preprocess_image(image):
+     img = np.array(image, dtype=np.uint8)
+     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+     results = mp_face_detection.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+
+     if not results.detections:
+         return None  # no face detected
+
+     detection = results.detections[0]
+     bbox = detection.location_data.relative_bounding_box
+     h, w, _ = img.shape
+     x1 = max(int(bbox.xmin * w), 0)
+     y1 = max(int(bbox.ymin * h), 0)
+     x2 = min(int((bbox.xmin + bbox.width) * w), w)
+     y2 = min(int((bbox.ymin + bbox.height) * h), h)
+
+     if x2 <= x1 or y2 <= y1:
+         return None
+
+     face_img = img[y1:y2, x1:x2]
+     face_img = cv2.resize(face_img, (224, 224))
+     face_img = cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)
+
+     inputs = feature_extractor(images=face_img, return_tensors="pt")
+     return inputs['pixel_values'].squeeze(0)
+
+ # --- Prediction
+ def predict(image):
      try:
+         inputs = preprocess_image(image)
+         if inputs is None:
+             return "Unknown", "⚠️ Wajah tidak terdeteksi. Silakan upload foto dengan wajah yang jelas.", []
+
+         inputs = inputs.unsqueeze(0).to(device)
+         with torch.no_grad():
+             outputs = model(inputs)
+         probs = torch.nn.functional.softmax(outputs.logits, dim=1)
+         pred_idx = torch.argmax(probs, dim=1).item()
+         pred_label = id2label[pred_idx]
+         pred_prob = probs[0][pred_idx].item() * 100
+
+         # --- threshold confidence
+         if pred_prob < 70:
+             return "Unknown", f"⚠️ Prediksi tidak yakin (Confidence {pred_prob:.2f}%). Silakan gunakan foto wajah yang lebih jelas.", []
+
+         # --- glasses recommendation
+         frame_recommendations = recommend_glasses_tree(pred_label)
+         gallery_items = []
+         for frame in frame_recommendations:
+             frame_image_path = glasses_images.get(frame)
+             if frame_image_path and os.path.exists(frame_image_path):
+                 gallery_items.append((Image.open(frame_image_path), frame))
+
+         description = face_shape_descriptions.get(pred_label, "tidak dikenali")
+         recommended_frames_text = ', '.join(frame_recommendations)
+
+         explanation = (
+             f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
+             f"Kamu memiliki bentuk wajah {description} "
+             f"Rekomendasi bentuk kacamata: {recommended_frames_text}."
+         )
+
+         return pred_label, explanation, gallery_items

      except Exception as e:
+         return "Error", f"Terjadi kesalahan: {str(e)}", []

+ # --- Gradio UI
  with gr.Blocks(theme=gr.themes.Soft()) as iface:
+     gr.Markdown("# 👓 Program Rekomendasi Kacamata Berdasarkan Bentuk Wajah")
      gr.Markdown("Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai.")

      with gr.Row():
          with gr.Column():
+             image_input = gr.Image(type="pil")
+             confirm_button = gr.Button("Konfirmasi")
              restart_button = gr.Button("Restart")
          with gr.Column():
+             detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
+             explanation_output = gr.Textbox(label="Penjelasan")
+             recommendation_gallery = gr.Gallery(label="Rekomendasi Kacamata", columns=3, show_label=False)
+
+     confirm_button.click(predict, inputs=image_input, outputs=[detected_shape, explanation_output, recommendation_gallery])
+     restart_button.click(lambda: (None, "", "", []), inputs=None, outputs=[image_input, detected_shape, explanation_output, recommendation_gallery])
+
+     gr.Markdown("**Sumber gambar kacamata**: Katalog dari [glassdirect.co.uk](https://www.glassdirect.co.uk)")

  if __name__ == "__main__":
+     iface.launch()
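
For a quick local check of the updated predict() outside the Gradio UI, a minimal sketch is shown below. It assumes the new app.py is importable as `app` and that a test photo exists at the hypothetical path sample_face.jpg; neither is part of this commit.

# Smoke-test sketch for the new predict(); "sample_face.jpg" is a hypothetical
# local photo and is not included in the repository.
from PIL import Image

from app import predict  # importing app builds the model but does not launch the UI

test_image = Image.open("sample_face.jpg").convert("RGB")  # 3-channel RGB, as preprocess_image expects
label, explanation, gallery = predict(test_image)
print("Detected face shape:", label)
print(explanation)
print("Recommended frames:", [caption for _, caption in gallery])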