File size: 2,488 Bytes
721508a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import cv2
import gradio as gr
import torch
from torchvision.transforms import v2

from model_old import ResNet18

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): ResNet18(1, 7) presumably means 1 input channel (grayscale) and
# 7 output classes (one per label in class_list) — confirm against model_old.
model = ResNet18(1, 7)
# Load the saved weights mapped onto the chosen device, move the module there
# and switch it to evaluation mode for inference.
model.load_state_dict(torch.load("./model.pth", map_location=device))
model.to(device).eval()

# Haar cascade used to locate faces before classification. cv2.CascadeClassifier
# does not raise when the XML file is missing or malformed; it produces an empty
# classifier and the failure only shows up later, inside detectMultiScale.
# Check it up front so a bad path is reported once, clearly, at startup.
_FACE_CASCADE_PATH = "./haarcascade_frontalface_default.xml"
face_cascade = cv2.CascadeClassifier(_FACE_CASCADE_PATH)
if face_cascade.empty():
    raise RuntimeError(
        f"Could not load Haar cascade from {_FACE_CASCADE_PATH!r}; "
        "make sure the file exists and is a valid OpenCV cascade XML."
    )

# Human-readable labels; predict() indexes this list with the argmax of the
# model's output probabilities.
class_list = [
    "angry",
    "disgust",
    "fear",
    "happy",
    "neutral",
    "sad",
    "surprise",
]

# Normalization statistics for the single grayscale channel.
# NOTE(review): presumably computed over the training images on a [0, 1] scale — confirm.
DATASET_MEAN = 0.5077385902404785
DATASET_STD = 0.255077600479126

# Per-crop preprocessing, in order:
# to grayscale, PIL -> tensor conversion (NOTE(review): presumably a pass-through
# for tensor inputs — verify), cast to float32 with value scaling, then
# standardize with the dataset statistics above.
_preprocess_steps = [
    v2.Grayscale(),
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(DATASET_MEAN,), std=(DATASET_STD,)),
]
preprocess = v2.Compose(_preprocess_steps)


def get_probs(image):
    """Return the model's class probabilities for a single face crop.

    Args:
        image: an ``H x W x C`` (or ``H x W`` single-channel) array-like face
            crop, e.g. the numpy array produced by ``cv2.resize`` in ``predict``.
            NOTE(review): multi-channel crops are assumed to be RGB, which is
            the order ``v2.Grayscale`` expects — confirm against the caller.

    Returns:
        A CPU tensor of softmax probabilities over the model's outputs
        (1-D for the usual ``(1, num_classes)`` model output).
    """
    tensor = torch.tensor(image)
    if tensor.ndim == 2:
        # Single-channel crop: add a trailing channel axis so it is H x W x 1.
        tensor = tensor.unsqueeze(-1)
    # HWC -> CHW, then add a batch axis of size 1 -> (1, C, H, W).
    batch = tensor.permute(2, 0, 1).unsqueeze(0)
    inp = preprocess(batch).to(device)
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(inp).squeeze()
    return torch.softmax(logits, 0).cpu()


def draw_labels(
    image,
    cords: tuple,
    label: str,
    box_color=(255, 0, 0),
    text_color=(0, 255, 0),
):
    """Draw a bounding box and its text label onto ``image`` in place.

    Args:
        image: the image array to draw on; it is modified in place.
        cords: ``(x, y, w, h)`` of the box, e.g. one row from ``detectMultiScale``.
        label: text written at the box's top-left corner.
        box_color: color of the rectangle outline, in the image's channel order.
        text_color: color of the label text, in the image's channel order.

    Returns:
        The same ``image`` object with the annotations drawn on it.
    """
    # Cast to plain Python ints so the point tuples are accepted by OpenCV
    # regardless of the numeric type stored in ``cords``.
    x, y, w, h = (int(v) for v in cords)
    cv2.rectangle(image, (x, y), (x + w, y + h), box_color, 2)

    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 1, 2
    (_, text_h), _ = cv2.getTextSize(label, font, scale, thickness)
    # putText's origin is the text's bottom-left corner. Keep it on the box's
    # top edge as before, but if the box is too close to the top of the frame
    # the label would be drawn off-image, so move it just inside the box.
    text_y = y if y >= text_h else y + text_h
    image = cv2.putText(
        image,
        label,
        (x, text_y),
        font,
        scale,
        text_color,
        thickness,
        cv2.LINE_AA,
    )
    return image


def predict(image):
    """Detect faces in an RGB image and annotate each with its predicted expression.

    Args:
        image: the numpy image delivered by the ``gr.Image`` input, or ``None``
            when no image is available yet (e.g. before the webcam starts).

    Returns:
        A copy of ``image`` with a box and label drawn for every detected face,
        or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    # gr.Image hands over RGB arrays by default, so convert from RGB (not BGR).
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)

    # Draw on a copy so every face is cropped from the untouched input;
    # otherwise boxes/labels drawn for earlier faces would leak into the crops
    # of later, overlapping faces (and the caller's array would be mutated).
    annotated = image.copy()
    for cords in faces:
        x, y, w, h = cords
        # cv2.resize's third positional parameter is ``dst``; the interpolation
        # flag only takes effect when passed by keyword.
        crop = cv2.resize(
            image[y : y + h, x : x + w], (48, 48), interpolation=cv2.INTER_AREA
        )
        probs = get_probs(crop)
        label = class_list[probs.argmax(0).item()]
        annotated = draw_labels(annotated, cords, label)

    return annotated


# Description text shared by both interfaces.
_CREDITS = 'Created by Czarna Magia AI Student Club'

# Live mode: webcam frames are streamed through `predict` as they arrive.
webcam_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(sources=['webcam'], streaming=True, label='Input webcam'),
    outputs=gr.Image(label='Output video'),
    live=True,
    title='Webcam mode',
    description=_CREDITS,
    theme=gr.themes.Soft(),
)

# Submit-based mode: one image taken from the webcam, a file upload or the clipboard.
img_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(sources=['webcam', 'upload', 'clipboard'], label='Input image'),
    outputs=gr.Image(label='Output image'),
    title='Image upload mode',
    description=_CREDITS,
    theme=gr.themes.Soft(),
)

# Combine both modes as tabs of a single app (image upload tab first) and serve it.
iface = gr.TabbedInterface(
    interface_list=[img_interface, webcam_interface],
    tab_names=['Image upload', 'Webcam'],
    title='Face Expression Recognizer',
    theme=gr.themes.Soft(),
)
iface.launch()