|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the selected device once at import time.
print("device:", device)
|
|
|
from torchvision import transforms |
|
import torchvision.models as models |
|
from torchvision.transforms import v2 |
|
|
|
import transformers |
|
from transformers import ViTImageProcessor |
|
from transformers import set_seed |
|
|
|
import datasets |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
from accelerate import Accelerator, notebook_launcher |
|
|
|
import os |
|
import PIL |
|
from glob import glob |
|
import pandas as pd |
|
import numpy as np |
|
import gradio as gr |
|
|
|
# Seed passed to `set_seed` before inference.
SEED = 42
# Batch size of the DataLoader that iterates the reference images.
BATCH_SIZE = 1
# Hugging Face checkpoint whose preprocessing config the image processor loads.
MODEL_TRANSFORMER = 'google/vit-base-patch16-224'
# Weights identifier handed to `SwinTEmbedding` (forwarded to `models.swin_t`).
MODEL = "IMAGENET1K_V1"
# NOTE(review): not referenced anywhere in the visible code — confirm before removing.
CLIP_SIZE = 224
# Directory that is globbed for reference `*.jpg` images.
data_path = 'employees'

# Fine-tuned checkpoints; only `model_path_2` is loaded by the visible code.
model_path_1 = 'models/Swin_tiny_celebA_imagenetWTS_loss0232.pt'
model_path_2 = 'models/Swin_tiny_celebA_imagenetWTS_loss0209.pt'

# Shared preprocessing pipeline for the reference images.
# `attn_implementation` / `torch_dtype` are model-loading options and have no
# meaning for an image processor, so they are intentionally not passed here.
image_processor = ViTImageProcessor.from_pretrained(MODEL_TRANSFORMER)
|
|
|
|
|
class CustomDataset(Dataset):
    """Dataset that reads image files from disk and preprocesses them.

    Each item is a dict with a single ``'pixel_values'`` entry holding the
    first element of the processor's ``'pixel_values'`` output for that image.
    """

    def __init__(self, image_paths, image_processor):
        """Store the file list and the preprocessing callable.

        Args:
            image_paths: Sequence of image file paths; item ``i`` loads
                ``image_paths[i]``.
            image_processor: Callable invoked as
                ``image_processor(image, return_tensors='pt')`` whose result
                is indexed with ``['pixel_values']``.
        """
        self.image_paths = image_paths
        self.image_processor = image_processor

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert to RGB and preprocess the image at position ``idx``.

        Returns:
            dict: ``{'pixel_values': <processed image>}``.
        """
        # Import the submodule explicitly: a bare `import PIL` does not
        # guarantee that `PIL.Image` is available as an attribute.
        from PIL import Image

        # `Image.open` is lazy and keeps the file handle open; the context
        # manager closes it as soon as the RGB copy has been materialised.
        with Image.open(self.image_paths[idx]) as raw_image:
            rgb_image = raw_image.convert("RGB")

        pixel_values = self.image_processor(rgb_image, return_tensors='pt')['pixel_values'][0]
        return {'pixel_values': pixel_values}
|
|
|
|
|
class SwinTEmbedding(nn.Module):
    """Embedding extractor built on the first child module of torchvision's Swin-T.

    The full ``swin_t`` network is kept as ``base_model`` and its first child
    is exposed as ``base_model_backbone``. Both attribute names take part in
    the ``state_dict`` keys, so they must stay as they are for saved
    checkpoints to load.
    """

    def __init__(self, model_name="DEFAULT"):
        """Build the Swin-T network.

        Args:
            model_name: Value forwarded as ``weights=`` to ``models.swin_t``.
        """
        super().__init__()
        self.model_name = model_name
        self.base_model = models.swin_t(weights=model_name)
        # First registered child of the network.
        # NOTE(review): presumably the `features` stage stack — confirm against torchvision.
        self.base_model_backbone = next(self.base_model.children())

    def forward(self, x):
        """Run the backbone and collapse everything after the batch dim into one axis."""
        backbone_out = self.base_model_backbone(x)
        return backbone_out.flatten(start_dim=1)
|
|
|
|
|
def prod_function(reconstructed_model, prod_dl, webcam_img):
    """Score every reference image against the webcam capture.

    Args:
        reconstructed_model: ``nn.Module`` mapping a pixel batch to embeddings.
        prod_dl: DataLoader yielding dicts with a ``'pixel_values'`` batch.
        webcam_img: Captured image in a form accepted by ``ViTImageProcessor``.

    Returns:
        list[torch.Tensor]: One tensor per DataLoader batch holding the cosine
        similarity (along dim 1) between each reference embedding and the
        webcam embedding. The list is empty when ``prod_dl`` yields nothing.
    """
    accelerator = Accelerator()
    set_seed(SEED)

    # `attn_implementation` / `torch_dtype` are model-loading options and are
    # meaningless for an image processor, so they are not passed.
    image_processor = ViTImageProcessor.from_pretrained(MODEL_TRANSFORMER)
    webcam_pixels = image_processor(webcam_img, return_tensors='pt')['pixel_values'][0]
    webcam_pixels = torch.unsqueeze(webcam_pixels, 0)
    # `Accelerator.prepare` only handles models, optimizers, dataloaders and
    # schedulers; a plain tensor is returned untouched. Move it explicitly so
    # it lives on the same device as the prepared model.
    webcam_pixels = webcam_pixels.to(accelerator.device)

    # Parameter-free module: nothing to place on a device, so it is not
    # handed to `prepare`.
    criterion = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    model, dataloader = accelerator.prepare(reconstructed_model, prod_dl)
    model.eval()

    # A single no-grad scope covers the webcam embedding and every batch.
    with torch.no_grad():
        webcam_emb = torch.flatten(model(webcam_pixels), start_dim=1)
        prod_predictions = [
            criterion(
                torch.flatten(model(batch['pixel_values']), start_dim=1),
                webcam_emb,
            )
            for batch in dataloader
        ]

    return prod_predictions
|
|
|
|
|
def face_recognition(webcan_img):
    """Identify the reference image most similar to the webcam capture.

    Args:
        webcan_img: Image delivered by the Gradio webcam component, or
            ``None`` when nothing has been captured.

    Returns:
        tuple: ``(person_name, similarity_score)`` where ``person_name`` is the
        file name (without directory and extension) of the best-matching
        reference image and ``similarity_score`` maps every reference image
        path to its cosine similarity as a Python float.

    Raises:
        gr.Error: If no image was captured or no reference ``*.jpg`` exists
            in ``data_path``.
    """
    if webcan_img is None:
        raise gr.Error("No webcam image captured. Please take a picture first.")

    # Sorted so that results (and tie-breaking in argmax) are reproducible;
    # glob itself returns files in arbitrary order.
    image_paths = sorted(glob(os.path.join(data_path, '*.jpg')))
    if not image_paths:
        # Without this guard `torch.cat` on an empty list fails with an opaque error.
        raise gr.Error(f"No reference images (*.jpg) found in '{data_path}'.")

    prod_ds = CustomDataset(image_paths=image_paths, image_processor=image_processor)
    # No shuffling: predictions must stay aligned with `image_paths`.
    prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)

    recon_model = SwinTEmbedding(model_name=MODEL)
    recon_model.load_state_dict(
        torch.load(model_path_2, weights_only=True, map_location=torch.device('cpu'))
    )

    prediction = prod_function(recon_model, prod_dl, webcan_img)

    # Bring the scores back to the CPU as one flat tensor.
    scores = torch.cat(prediction, 0).detach().cpu()

    # Plain floats render cleanly in the UI (instead of `tensor(..., device=...)`).
    similarity_score = dict(zip(image_paths, scores.tolist()))

    best_idx = int(scores.argmax(-1))
    # `os.path` handles both path separators and dots inside the file name.
    person_name = os.path.splitext(os.path.basename(image_paths[best_idx]))[0]

    return person_name, similarity_score
|
|
|
|
|
def load_about_md():
    """Return the contents of ``about.md`` from the current working directory.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default encoding.

    Raises:
        OSError: If ``about.md`` cannot be opened.
    """
    with open("about.md", "r", encoding="utf-8") as file:
        return file.read()
|
|
|
# Gradio UI: an "About" tab rendered from about.md and a "Face recognition"
# tab with a webcam capture, a submit button and two text outputs.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost — confirm the layout matches the intended design.
with gr.Blocks() as demo:
    gr.Markdown("# Face Recognition app 2.0")

    with gr.Tab("About the App"):
        gr.Markdown(load_about_md())

    with gr.Tab("Face recognition"):
        with gr.Row():
            # `scale` must be an integer (Gradio warns on fractional values);
            # 9 : 10 keeps the original 0.9 : 1 width ratio.
            with gr.Column(scale=9, variant="panel"):
                with gr.Row(height=350, variant="panel"):
                    webcam_input = gr.Image(sources=["webcam"], type="pil", label="Face Capture")
                with gr.Row(variant="panel"):
                    image_button = gr.Button("Submit")
                    recognition_output = gr.Textbox(label="Face Recognised as:")
            with gr.Column(scale=10, variant="panel"):
                with gr.Row():
                    similarity_score_output = gr.Textbox(label="Face Similarity Score:")

        # Run recognition on the captured frame and fill both text boxes.
        image_button.click(
            face_recognition,
            inputs=webcam_input,
            outputs=[recognition_output, similarity_score_output],
        )

if __name__ == "__main__":
    demo.launch()