```python
import torch
from PIL import Image
from torchvision import transforms
import mediapipe as mp
import numpy as np
import math
import requests
import gradio as gr
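# Download the pretrained TorchScript FER model (a ResNet-50 trained on AffectNet)
# from Hugging Face and load it onto the GPU.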
model_url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
model_path = "FER_static_ResNet50_AffectNet.pth"

response = requests.get(model_url, stream=True)
with open(model_path, 'wb') as file:
    for chunk in response.iter_content(chunk_size=8192):
        file.write(chunk)

pth_model = torch.jit.load(model_path).to('cuda')
pth_model.eval()
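# Index-to-label mapping for the seven emotion classes, and the MediaPipe
# Face Mesh solution used to locate the face in the input image.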
DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}
mp_face_mesh = mp.solutions.face_mesh
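# Preprocess a PIL face crop for the classifier: resize to 224x224, convert to a
# tensor, reverse the channel order (RGB -> BGR), subtract per-channel means, and
# add a batch dimension on the GPU.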
def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def __init__(self):
            super(PreprocessInput, self).__init__()

        def forward(self, x):
            x = x.to(torch.float32)
            x = torch.flip(x, dims=(0,))  # RGB -> BGR channel order
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img):
        ttransform = transforms.Compose([
            transforms.PILToTensor(),
            PreprocessInput()
        ])
        img = img.resize((224, 224), Image.Resampling.NEAREST)
        img = ttransform(img)
        img = torch.unsqueeze(img, 0).to('cuda')
        return img

    return get_img_torch(fp)
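# Convert MediaPipe's normalized landmark coordinates to pixel coordinates,
# clamped to the image dimensions.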
def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px
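# Compute a bounding box that encloses all detected face landmarks, clipped to
# the image bounds.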
def get_box(fl, w, h):
    idx_to_coors = {}
    for idx, landmark in enumerate(fl.landmark):
        landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)
        if landmark_px:
            idx_to_coors[idx] = landmark_px

    x_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 0])
    y_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 1])
    endX = np.max(np.asarray(list(idx_to_coors.values()))[:, 0])
    endY = np.max(np.asarray(list(idx_to_coors.values()))[:, 1])

    (startX, startY) = (max(0, x_min), max(0, y_min))
    (endX, endY) = (min(w - 1, endX), min(h - 1, endY))

    return startX, startY, endX, endY
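# Detect a face in the input image, crop it, and return the crop together with
# the softmax probabilities over the seven emotion classes.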
def predict(inp):
    inp = np.array(inp)
    h, w = inp.shape[:2]

    with mp_face_mesh.FaceMesh(
            max_num_faces=1,
            refine_landmarks=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as face_mesh:
        results = face_mesh.process(inp)
        if results.multi_face_landmarks:
            for fl in results.multi_face_landmarks:
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = inp[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face))
                prediction = torch.nn.functional.softmax(pth_model(cur_face_n), dim=1).cpu().detach().numpy()[0]
                confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
                return cur_face, confidences

    # No face detected: return empty outputs so Gradio can clear both components.
    return None, None
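# Reset the input image, the cropped-face preview, and the emotion label.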
def clear():
    return (
        gr.Image(value=None, type="pil"),
        gr.Image(value=None, scale=1, elem_classes="dl2"),
        gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
    )
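# Custom CSS: fixed heights for the image panels and primary-button styling for
# the Submit button.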
style = """
div.dl1 div.upload-container {
    height: 350px;
    max-height: 350px;
}
div.dl2 {
    max-height: 200px;
}
div.dl2 img {
    max-height: 200px;
}
.submit {
    display: inline-block;
    padding: 10px 20px;
    font-size: 16px;
    font-weight: bold;
    text-align: center;
    text-decoration: none;
    cursor: pointer;
    border: var(--button-border-width) solid var(--button-primary-border-color);
    background: var(--button-primary-background-fill);
    color: var(--button-primary-text-color);
    border-radius: 8px;
    transition: all 0.3s ease;
}
.submit[disabled] {
    cursor: not-allowed;
    opacity: 0.6;
}
.submit:hover:not([disabled]) {
    border-color: var(--button-primary-border-color-hover);
    background: var(--button-primary-background-fill-hover);
    color: var(--button-primary-text-color-hover);
}
.submit:active:not([disabled]) {
    transform: scale(0.98);
}
"""
with gr.Blocks(css=style) as demo:
    with gr.Row():
        with gr.Column(scale=2, elem_classes="dl1"):
            input_image = gr.Image(type="pil")
            with gr.Row():
                submit = gr.Button(
                    value="Submit", interactive=True, scale=1, elem_classes="submit"
                )
                clear_btn = gr.Button(
                    value="Clear", interactive=True, scale=1
                )
        with gr.Column(scale=1, elem_classes="dl4"):
            output_image = gr.Image(scale=1, elem_classes="dl2")
            output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")

    gr.Examples(
        ["images/fig7.jpg", "images/fig1.jpg", "images/fig2.jpg", "images/fig3.jpg",
         "images/fig4.jpg", "images/fig5.jpg", "images/fig6.jpg"],
        [input_image],
    )
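    # Wire the buttons to the prediction and reset callbacks.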
    submit.click(
        fn=predict,
        inputs=[input_image],
        outputs=[
            output_image,
            output_label,
        ],
        queue=True,
    )

    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[
            input_image,
            output_image,
            output_label,
        ],
        queue=True,
    )
if __name__ == "__main__":
    demo.queue(api_open=False).launch(share=False)
```
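For a quick sanity check without launching the UI, `predict` can also be called directly on a PIL image. The sketch below is only illustrative: `face.jpg` is a hypothetical file name standing in for any local photo with a visible face, and a CUDA-capable GPU is still required because the model and tensors are moved to `'cuda'`.

```python
# Illustrative only; "face.jpg" is a placeholder for any local image with a face.
from PIL import Image

face_crop, scores = predict(Image.open("face.jpg").convert("RGB"))
if scores is not None:
    # Print the most likely emotion followed by the full probability dictionary.
    print(max(scores, key=scores.get), scores)
```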