EF / app.py
Junior16's picture
Update app.py
c0eba86 verified
from fastapi import FastAPI, File, UploadFile, HTTPException
import cv2
import numpy as np
from PIL import Image
import io
import base64
from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch
app = FastAPI()
# Cargar el modelo de clasificaci贸n de edad y el extractor
model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
@app.post("/detect/")
async def detect_face(file: UploadFile = File(...)):
try:
# Leer y procesar la imagen cargada
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes))
img_np = np.array(image)
if img_np.shape[2] == 4:
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGRA2BGR)
# Cargar el clasificador Haar para detecci贸n de rostros
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
if len(faces) == 0:
raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.")
# Procesar cada rostro detectado
results = []
for (x, y, w, h) in faces:
# Extraer el rostro de la imagen
face_img = img_np[y:y+h, x:x+w]
pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
# Realizar la predicci贸n de edad
inputs = transforms(pil_face_img, return_tensors='pt')
output = model(**inputs)
proba = output.logits.softmax(1)
preds = proba.argmax(1)
# Asumimos que la predicci贸n est谩 representando un rango de edad (esto puede adaptarse m谩s tarde)
predicted_age_range = str(preds.item())
# Dibujar un rect谩ngulo alrededor del rostro y a帽adir la edad predicha
cv2.rectangle(img_np, (x, y), (x+w, y+h), (255, 0, 0), 2)
cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
results.append({
"edad_predicha": predicted_age_range,
"coordenadas_rostro": (x, y, w, h)
})
# Convertir la imagen procesada a base64
result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
img_byte_arr = io.BytesIO()
result_image.save(img_byte_arr, format='JPEG')
img_byte_arr = img_byte_arr.getvalue()
return {
"message": "Rostros detectados y edad predicha",
"rostros": len(faces),
"resultados": results,
"imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8')
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))