Spaces:
Runtime error
Runtime error
File size: 5,464 Bytes
ddb02da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
from flask import Flask, request, jsonify, send_file
from PIL import Image
import base64
import spaces
from loadimg import load_img
from io import BytesIO
import numpy as np
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import threading
import logging
import uuid
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Flask application object; the @app.route handlers in this file register on it.
app = Flask(__name__)
# Configure logging: emit INFO and above through the root logger.
logging.basicConfig(level=logging.INFO)
# Load the model lazily: both globals start as None and are populated by
# load_model(), which detect_and_segment_persons() calls when detector is None.
model = None
detector = None
def load_model():
    """Download the person-detection weights and populate the module globals.

    Fetches ``models/scrfd_person_2.5g.onnx`` from the
    ``public-data/insightface`` Hub repository, opens it in a CPU-only ONNX
    Runtime session and wraps that session in an insightface ``RetinaFace``
    object. The prepared object is stored in both ``model`` and ``detector``.
    """
    global model, detector

    weights_path = huggingface_hub.hf_hub_download(
        "public-data/insightface", "models/scrfd_person_2.5g.onnx"
    )

    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 8
    session_options.inter_op_num_threads = 8

    # Restrict the session to CPUExecutionProvider so no other provider is used.
    cpu_session = ort.InferenceSession(
        weights_path,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )

    model = insightface.model_zoo.retinaface.RetinaFace(
        model_file=weights_path, session=cpu_session
    )
    # NOTE(review): the first argument (-1) is presumably insightface's
    # "use CPU" context id -- confirm against the insightface docs.
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model

    logging.info("Model loaded successfully.")
# Request the "high" float32 matmul precision setting.
torch.set_float32_matmul_precision("high")

# Load BiRefNet for segmentation and keep it on the CPU.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cpu")

# Preprocessing applied to every image before it is fed to BiRefNet:
# fixed 1024x1024 resize, tensor conversion, per-channel normalisation.
transform_image = transforms.Compose([
    transforms.Resize(size=(1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def save_image(img):
    """Save *img* under a fresh ``<uuid4>.png`` name and return that name.

    The name has no directory component; ``img.save`` decides where a
    relative name ends up.
    """
    file_name = f"{uuid.uuid4()}.png"
    img.save(file_name)
    return file_name
def decode_image_from_base64(image_data):
    """Decode a base64-encoded image payload into an RGB ``PIL.Image.Image``.

    The decoded bytes are opened with ``Image.open`` and converted to "RGB".
    """
    raw_bytes = base64.b64decode(image_data)
    byte_stream = BytesIO(raw_bytes)
    return Image.open(byte_stream).convert("RGB")
def encode_image_to_base64(image):
    """Serialise *image* as PNG and return the bytes as a base64 text string."""
    png_stream = BytesIO()
    # PNG is used for compatibility with RGBA images.
    image.save(png_stream, format="PNG")
    encoded = base64.b64encode(png_stream.getvalue())
    return encoded.decode('utf-8')
# Remove background using BiRefNet on CPU
def rm_background(image):
    """Return *image* with an alpha channel taken from BiRefNet's mask.

    Args:
        image: Anything ``load_img`` accepts.

    Returns:
        The image produced by ``load_img`` with the predicted mask, resized
        back to the input's dimensions, installed as its alpha channel.
    """
    source = load_img(image, output_type="pil").convert("RGB")
    original_size = source.size

    image = load_img(source)
    batch = transform_image(image).unsqueeze(0).to("cpu")  # Ensure CPU usage

    # Prediction: inference only, so gradient tracking is disabled.
    with torch.no_grad():
        # The last element of the model output is used as the mask logits.
        preds = birefnet(batch)[-1].sigmoid().cpu()

    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(original_size)
    image.putalpha(mask)
    return image
# Remove background with the transparent background remover
def remove_background(image):
    """Run ``transparent_background.Remover`` on a PIL image or NumPy array.

    The input type is checked before the ``Remover`` is constructed, so an
    unsupported input is rejected without first instantiating the remover.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray``; arrays are
            converted with ``Image.fromarray`` first.

    Returns:
        Whatever ``Remover.process`` returns for the image.

    Raises:
        TypeError: If *image* is neither a PIL image nor a NumPy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError("Unsupported image type")

    return Remover().process(image)
def detect_and_segment_persons(image, clothes):
    """Detect persons in *image* and save clothing segmentations for each.

    Args:
        image: Image convertible by ``np.array`` to a 3-D array whose last
            axis is reversed for the detector (the caller in this file
            passes an RGB PIL image).
        clothes: Clothing labels forwarded unchanged to ``segment_clothing``.

    Returns:
        A list of file names produced by ``save_image``. When the detector
        finds nobody, the list holds a single entry: the whole image with
        its background removed. Boxes that collapse to zero area after
        rounding and clipping are skipped.
    """
    rgb = np.array(image)
    bgr = rgb[:, :, ::-1]  # RGB -> BGR for the detector

    if detector is None:
        load_model()  # Ensure the model is loaded

    bboxes, _ = detector.detect(bgr)
    if bboxes.shape[0] == 0:
        return [save_image(rm_background(image))]

    height, width = bgr.shape[:2]
    boxes = np.round(bboxes[:, :4]).astype(int)
    # Clamp x coordinates to [0, width] and y coordinates to [0, height].
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in boxes:
        if x2 <= x1 or y2 <= y1:
            # An empty crop cannot be turned into an image or segmented.
            logging.warning(
                "Skipping empty person box (%s, %s, %s, %s)", x1, y1, x2, y2
            )
            continue

        # Crop straight from the RGB array; no BGR round trip is needed.
        person = Image.fromarray(rgb[y1:y2, x1:x2])
        cutout = rm_background(person)
        segmented_result = segment_clothing(cutout, clothes)

        image_paths = [save_image(segment) for segment in segmented_result]
        logging.info("Saved segmented images: %s", image_paths)
        all_segmented_images.extend(image_paths)

    return all_segmented_images
@app.route('/', methods=['GET'])
def welcome():
    """Landing route: reply with a plain-text greeting."""
    greeting = "Welcome to Clothing Segmentation API"
    return greeting
@app.route('/api/detect', methods=['POST'])
def detect():
    """Segment clothing for every person found in a posted image.

    Expects a JSON object with an ``image`` field holding base64 image data.

    Returns:
        200 with ``{"images": [...]}`` (file names from ``save_image``);
        400 with ``{"error": ...}`` when the body is not such a JSON object
        or the image data cannot be decoded;
        500 with ``{"error": ...}`` for any other failure.
    """
    # silent=True yields None instead of raising on a missing or invalid body,
    # so malformed requests are answered with a 400 below rather than a 500.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or not isinstance(payload.get('image'), str):
        return jsonify(
            {'error': "Request body must be JSON with a base64 'image' field"}
        ), 400

    try:
        try:
            image = decode_image_from_base64(payload['image'])
        except (ValueError, OSError) as e:
            # Bad base64 raises a ValueError subclass; data PIL cannot read
            # raises an OSError subclass. Both are the client's mistake.
            logging.warning("Rejected undecodable image: %s", e)
            return jsonify({'error': f"Invalid image data: {e}"}), 400

        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        result = detect_and_segment_persons(image, clothes)
        return jsonify({'images': result})
    except Exception as e:
        # Top-level boundary: log with traceback and report a server error.
        logging.exception("Error occurred: %s", e)
        return jsonify({'error': str(e)}), 500
def _is_generated_image_name(name):
    """Return True only for names of the form ``<canonical uuid>.png``."""
    stem, dot, extension = name.rpartition(".")
    if dot != "." or extension != "png":
        return False
    try:
        # Round-tripping through uuid.UUID rejects every non-canonical form
        # (braces, urn: prefix, missing hyphens, upper case).
        return str(uuid.UUID(stem)) == stem
    except ValueError:
        return False


# Route to retrieve the generated image
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """Serve a PNG previously written by ``save_image``.

    Only names shaped like the ones ``save_image`` generates
    (``<uuid4>.png``) are served; any other name gets the same 404 as a
    missing file, so other files next to the app cannot be downloaded
    through this route.

    Returns:
        The file as ``image/png``, or 404 with ``{"error": "Image not found"}``.
    """
    if not _is_generated_image_name(image_id):
        return jsonify({'error': 'Image not found'}), 404

    try:
        return send_file(image_id, mimetype='image/png')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404
if __name__ == "__main__":
    import os

    # The server listens on all interfaces, and Flask's debug mode enables
    # the Werkzeug interactive debugger, which lets any client that can reach
    # it execute code. Debug is therefore off unless explicitly requested
    # through the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)
|