# Hugging Face Spaces page header (Space status at capture time: Sleeping)
| import os | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import base64 | |
| import requests | |
| import json | |
| import time | |
| from transformers import DetrForObjectDetection, DetrImageProcessor | |
| import torch | |
# Loaded DETR (model, processor) pairs keyed by checkpoint name, so the
# weights are read once per process instead of on every request.
_DETR_CACHE = {}


def _load_detr(checkpoint='facebook/detr-resnet-50'):
    """Return the cached ``(model, processor)`` pair for *checkpoint*."""
    if checkpoint not in _DETR_CACHE:
        model = DetrForObjectDetection.from_pretrained(checkpoint)
        processor = DetrImageProcessor.from_pretrained(checkpoint)
        _DETR_CACHE[checkpoint] = (model, processor)
    return _DETR_CACHE[checkpoint]


def detect_face_and_neck(image, *, score_threshold=0.7):
    """Detect the regions used to anchor jewelry on a person image.

    Runs the ``facebook/detr-resnet-50`` object detector on *image* and, for
    each label of interest, keeps the highest-scoring detection whose score
    is strictly above *score_threshold*.

    Args:
        image: Image array whose first two dimensions are (height, width),
            or ``None``.
        score_threshold: Minimum confidence a detection must exceed.

    Returns:
        A ``(face_box, neck_box)`` tuple. Each entry is either ``None`` (no
        qualifying detection, or *image* is ``None``) or the detector's box
        tensor. Per the Transformers documentation these boxes are
        ``(xmin, ymin, xmax, ymax)`` in absolute pixel coordinates, not
        ``(x, y, w, h)``.
    """
    if image is None:
        return None, None

    model, processor = _load_detr()
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    target_sizes = torch.tensor([image.shape[:2]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes
    )[0]

    # The whole-person box is used as the "neck" anchor.
    person_label = 1
    # NOTE(review): this checkpoint is trained on COCO, which has no "face"
    # class; label 2 appears to be "bicycle" in its id2label map. Confirm and
    # switch to a face/landmark detector if earrings are to be supported.
    face_label = 2

    best_box = {person_label: None, face_label: None}
    best_score = {person_label: score_threshold, face_label: score_threshold}
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        label_id = int(label)
        score_value = float(score)
        if label_id in best_box and score_value > best_score[label_id]:
            best_box[label_id] = box
            best_score[label_id] = score_value

    return best_box[face_label], best_box[person_label]
def place_jewelry(image, jewelry_image, jewelry_type, position):
    """Paste *jewelry_image* onto *image* inside the rectangle *position*.

    The jewelry picture is resized to the rectangle's size and written into
    *image* in place. Any part of the rectangle that falls outside the image
    is clipped instead of raising a shape-mismatch error.

    Args:
        image: Target image array (height, width[, channels]); modified in
            place.
        jewelry_image: Overlay image array. A 4-channel overlay on a
            3-channel target is alpha-blended; otherwise pixels are copied.
        jewelry_type: Unused here; kept so existing callers keep working.
        position: ``(x, y, w, h)`` of the target rectangle. Values may be
            ints, floats or 0-d tensors; they are rounded to integers.

    Returns:
        The same *image* object, with the overlay applied when the rectangle
        intersects the image and has a positive size.
    """
    x, y, w, h = (int(round(float(v))) for v in position)
    if w <= 0 or h <= 0:
        return image

    img_h, img_w = image.shape[:2]
    left, top = max(x, 0), max(y, 0)
    right, bottom = min(x + w, img_w), min(y + h, img_h)
    if left >= right or top >= bottom:
        # Rectangle lies entirely outside the image: nothing to draw.
        return image

    resized_jewelry = cv2.resize(jewelry_image, (w, h))
    # Keep only the part of the overlay that lands inside the image.
    patch = resized_jewelry[top - y:bottom - y, left - x:right - x]
    region = image[top:bottom, left:right]  # view: writes go into `image`

    has_alpha = patch.ndim == 3 and patch.shape[2] == 4
    if has_alpha and region.ndim == 3 and region.shape[2] == 3:
        # Assumes an 8-bit alpha channel (0-255) -- TODO confirm for inputs
        # that are not uint8.
        alpha = patch[..., 3:4].astype(np.float32) / 255.0
        blended = alpha * patch[..., :3] + (1.0 - alpha) * region
        region[...] = blended.astype(image.dtype)
    else:
        region[...] = patch
    return image
def _box_to_xywh(box, image_shape):
    """Convert an ``(xmin, ymin, xmax, ymax)`` box to a clipped ``(x, y, w, h)``.

    Returns ``None`` when the box has no area inside an image of
    *image_shape* (height, width, ...).
    """
    img_h, img_w = image_shape[:2]
    x_min, y_min, x_max, y_max = (int(round(float(v))) for v in box)
    x_min, y_min = max(x_min, 0), max(y_min, 0)
    x_max, y_max = min(x_max, img_w), min(y_max, img_h)
    if x_max <= x_min or y_max <= y_min:
        return None
    return x_min, y_min, x_max - x_min, y_max - y_min


def tryon_jewelry(person_img, jewelry_img, jewelry_type):
    """Overlay a jewelry image on a person image for the chosen jewelry type.

    "Necklace" is anchored on the second box returned by
    ``detect_face_and_neck`` and "Earrings" on the first. Any other type, a
    missing input image, or a missing detection returns *person_img*
    unchanged.

    Args:
        person_img: Person image array, or ``None`` if nothing was uploaded.
        jewelry_img: Jewelry image array, or ``None`` if nothing was uploaded.
        jewelry_type: One of the dropdown values, e.g. ``"Necklace"``.

    Returns:
        The person image, with the jewelry drawn on it when placement was
        possible.
    """
    if person_img is None or jewelry_img is None:
        return person_img
    # Only these two types have a placement rule; skip the expensive
    # detection for everything else.
    if jewelry_type not in ("Necklace", "Earrings"):
        return person_img

    face_box, neck_box = detect_face_and_neck(person_img)
    box = neck_box if jewelry_type == "Necklace" else face_box
    if box is None:
        return person_img  # Return original image if no detection

    # The detector reports corner coordinates, while place_jewelry expects
    # an integer (x, y, w, h) rectangle inside the image.
    position = _box_to_xywh(box, person_img.shape)
    if position is None:
        return person_img
    return place_jewelry(person_img, jewelry_img, jewelry_type, position)
# Gradio interface setup
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# this file; confirm which column the result image and Run button belong to.

# Stylesheet passed to gr.Blocks: centers the three columns (matched by
# their elem_id values) and caps each at 430px wide.
css = """
#col-left, #col-mid, #col-right {
margin: 0 auto;
max-width: 430px;
}
"""
with gr.Blocks(css=css) as JewelryTryon:
    gr.HTML("<h1>Virtual Jewelry Try-On</h1>")
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            # Upload-only input for the person photo, requested as a numpy array.
            imgs = gr.Image(label="Person image", sources='upload', type="numpy")
        with gr.Column(elem_id="col-mid"):
            # Upload-only input for the jewelry picture, requested as a numpy array.
            garm_img = gr.Image(label="Jewelry image", sources='upload', type="numpy")
        with gr.Column(elem_id="col-right"):
            # 'Ring' is selectable, but tryon_jewelry has no branch for it and
            # returns the person image unchanged.
            jewelry_type = gr.Dropdown(label="Jewelry Type", choices=['Necklace', 'Earrings', 'Ring'], value="Necklace")
            image_out = gr.Image(label="Result", show_share_button=False)
            run_button = gr.Button(value="Run")
    # Clicking Run calls tryon_jewelry with the two images and the selected
    # type, and shows the returned image in the result component.
    run_button.click(fn=tryon_jewelry, inputs=[imgs, garm_img, jewelry_type], outputs=image_out)
# Launch Gradio app
# Enables the request queue and starts the app; api_open=False and
# show_api=False are the API-related switches turned off here.
JewelryTryon.queue(api_open=False).launch(show_api=False)