File size: 2,488 Bytes
721508a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import cv2
import gradio as gr
import torch
from torchvision.transforms import v2
from model_old import ResNet18
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the arguments (1, 7) are presumably "1 input channel,
# 7 output classes" -- ResNet18 lives in model_old; confirm there.
model = ResNet18(1, 7)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
model.load_state_dict(
    torch.load("./model.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # inference behaviour for layers such as dropout / batch norm

face_cascade = cv2.CascadeClassifier("./haarcascade_frontalface_default.xml")
# OpenCV does not raise when the cascade file is missing or unreadable; it
# hands back an empty classifier that only fails later inside
# detectMultiScale. Fail at start-up with a clear message instead.
if face_cascade.empty():
    raise RuntimeError(
        "Could not load face cascade from "
        "'./haarcascade_frontalface_default.xml'"
    )

# predict() indexes this list with the argmax of the model output, so the
# order must match the class indices the model was trained with.
class_list = ["angry", "disgust", "fear", "happy", "neutral", "sad", "surprise"]

# NOTE(review): presumably the mean / std of the training set on a [0, 1]
# intensity scale (they are applied after ToDtype(scale=True)) -- confirm.
DATASET_MEAN = 0.5077385902404785
DATASET_STD = 0.255077600479126

# Preprocessing applied to every face crop before it reaches the model.
# NOTE(review): get_probs() feeds a tensor, not a PIL image, so PILToTensor is
# presumably a pass-through here -- confirm against the torchvision v2 docs.
preprocess = v2.Compose(
    [
        v2.Grayscale(),
        v2.PILToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(DATASET_MEAN,), std=(DATASET_STD,)),
    ]
)
def get_probs(image):
    """Return the softmax class probabilities for a single face crop.

    Args:
        image: Array-like with three axes that ``torch.tensor`` accepts. The
            last axis is moved to the front before preprocessing, so a
            height x width x channels layout is assumed (``predict`` passes a
            48x48 crop of the input frame).

    Returns:
        A CPU tensor with the softmax of the squeezed model output; it is not
        attached to any autograd graph.
    """
    # Channels-last -> channels-first, then add a leading batch axis.
    batch = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)
    # Pure inference: without no_grad() every call would record an autograd
    # graph and return a tensor that keeps it alive, wasting memory and time.
    with torch.no_grad():
        logits = model(preprocess(batch).to(device)).squeeze()
        return torch.softmax(logits, 0).cpu()
def draw_labels(image, cords: tuple, label: str):
    """Outline a region of ``image`` and write ``label`` at its top-left corner.

    Args:
        image: Image array handed to the OpenCV drawing calls; the rectangle
            is drawn into it directly.
        cords: Four values ``(x, y, w, h)`` -- top-left corner, width and
            height of the region to outline.
        label: Text to render at ``(x, y)``.

    Returns:
        The value returned by ``cv2.putText`` for the annotated image.
    """
    left, top, width, height = cords
    top_left = (left, top)
    bottom_right = (left + width, top + height)

    # Box first, then the caption anchored at the box's top-left corner.
    cv2.rectangle(image, top_left, bottom_right, (255, 0, 0), 2)
    return cv2.putText(
        image,
        label,
        top_left,
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 255, 0),
        2,
        cv2.LINE_AA,
    )
def predict(image):
    """Detect faces in a frame and label each one with its predicted expression.

    Args:
        image: Frame supplied by the Gradio image component as a
            height x width x 3 array, or ``None`` when no frame is available
            (for example a cleared input). Gradio's numpy images are in RGB
            channel order by default.

    Returns:
        An annotated copy of ``image`` with a box and class label drawn for
        every detected face (an unannotated copy when no face is found), or
        ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to process; returning None leaves the output component empty.
        return None

    # The frame is RGB, so use the RGB conversion: BGR2GRAY would swap the
    # red and blue weights when computing luminance.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.1, 4)

    # Draw on a copy: keeps the caller's array intact and guarantees that
    # boxes/text drawn for one face never leak into the crop of another.
    annotated = image.copy()
    for cords in faces:
        x, y, w, h = cords
        # The third positional parameter of cv2.resize is ``dst``, not the
        # interpolation flag, so the flag has to be passed by keyword.
        face = cv2.resize(
            image[y : y + h, x : x + w],
            (48, 48),
            interpolation=cv2.INTER_AREA,
        )
        probs = get_probs(face)
        label = class_list[probs.argmax(0).item()]
        annotated = draw_labels(annotated, cords, label)
    return annotated
# Live tab: frames are streamed from the webcam and run through predict().
webcam_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input webcam", sources=["webcam"], streaming=True),
    outputs=gr.Image(label="Output video"),
    title="Webcam mode",
    description="Created by Czarna Magia AI Student Club",
    live=True,
    theme=gr.themes.Soft(),
)

# Single-image tab: accepts a webcam snapshot, an uploaded file or a
# clipboard paste and runs the same predict() function on it.
img_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(
        label="Input image",
        sources=["webcam", "upload", "clipboard"],
    ),
    outputs=gr.Image(label="Output image"),
    title="Image upload mode",
    description="Created by Czarna Magia AI Student Club",
    theme=gr.themes.Soft(),
)

# Combine both interfaces into one tabbed app; the image tab comes first.
iface = gr.TabbedInterface(
    interface_list=[img_interface, webcam_interface],
    tab_names=["Image upload", "Webcam"],
    title="Face Expression Recognizer",
    theme=gr.themes.Soft(),
)

# Start the Gradio server.
iface.launch()
|