File size: 4,126 Bytes
4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 fc8bd63 4958ab0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import io
from typing import Optional

import open_clip
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# === 1) Load the CLIP ViT-B/16 model with the OpenAI pretrained weights ===
# The second return value is not needed here; `preprocess` is kept for its
# Normalize statistics.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-16", pretrained="openai"
)
# Move to the target device and switch to inference mode.
clip_model = clip_model.to(DEVICE).eval()
# Freeze every weight: the model is only used for inference.
for weight in clip_model.parameters():
    weight.requires_grad = False
# === 2) Load the text embeddings generated with ViT-B/16 ===
# (Make sure these files exist: they were generated as
# text_embeddings_modelos_b16.pt and text_embeddings_b16.pt.)


def _load_text_embeddings(path):
    """Load a ``{"labels", "embeddings"}`` checkpoint and L2-normalize it.

    Args:
        path: Path of a file written with ``torch.save`` containing a dict
            with a ``"labels"`` sequence and an ``"embeddings"`` tensor whose
            rows are aligned with the labels.

    Returns:
        ``(labels, embeddings)`` where ``embeddings`` lives on ``DEVICE`` and
        every row has unit L2 norm, so a dot product is cosine similarity.

    Raises:
        ValueError: If the number of labels and embedding rows differ. A
            mismatch would otherwise silently attach wrong labels (or raise
            IndexError) at request time instead of failing at startup.
    """
    # NOTE(review): torch.load unpickles its input; only point it at trusted files.
    checkpoint = torch.load(path, map_location=DEVICE)
    labels = checkpoint["labels"]
    embeddings = checkpoint["embeddings"].to(DEVICE)
    if len(labels) != embeddings.shape[0]:
        raise ValueError(
            f"{path}: {len(labels)} etiquetas pero {embeddings.shape[0]} embeddings"
        )
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    return labels, embeddings


model_labels, model_embeddings = _load_text_embeddings("text_embeddings_modelos_b16.pt")
version_labels, version_embeddings = _load_text_embeddings("text_embeddings_b16.pt")
# Image transform: reuse the Normalize statistics of the ViT-B/16 preprocess
# pipeline, but resize straight to 224x224 (no center crop, so the aspect
# ratio is not preserved).
normalize = next(
    step
    for step in preprocess.transforms
    if isinstance(step, transforms.Normalize)
)
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize.mean, std=normalize.std),
    ]
)

app = FastAPI()
def predict_top(text_feats, text_labels, image_tensor, topk=3, model=None):
    """Rank text labels by similarity to an image and return the best ones.

    Args:
        text_feats: Text embeddings, one row per label (row ``i`` belongs to
            ``text_labels[i]``). Expected to be L2-normalized already; this
            function only normalizes the image features.
        text_labels: Label strings aligned with the rows of ``text_feats``.
        image_tensor: Preprocessed image batch passed to ``encode_image``;
            only the first image (index 0) is scored.
        topk: Maximum number of results. It is clamped to the number of
            embeddings, because ``torch.topk`` raises when ``k`` exceeds the
            dimension size.
        model: Object exposing ``encode_image``; defaults to the module-level
            ``clip_model``.

    Returns:
        A list of ``{"label": str, "confidence": float}`` dicts, best first.
        ``confidence`` is the softmax probability as a percentage rounded to
        2 decimals. The list is empty when there is nothing to rank.
    """
    if model is None:
        model = clip_model

    k = min(topk, text_feats.shape[0])
    if k <= 0:
        return []

    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Scale cosine similarities by a fixed factor of 100 before the softmax.
        probs = (100.0 * image_features @ text_feats.T).softmax(dim=-1)
        values, indices = torch.topk(probs[0], k=k)

    return [
        {"label": text_labels[i], "confidence": round(p * 100, 2)}
        for p, i in zip(values.tolist(), indices.tolist())
    ]
def process_image(image_bytes: bytes, umbral_version: float = 25) -> dict:
    """Identify the brand, model and version of the vehicle shown in an image.

    Classification runs in two stages against the precomputed text
    embeddings: first over every model label, then only over the version
    labels that belong to the predicted model.

    Args:
        image_bytes: Encoded image bytes, decoded with PIL.
        umbral_version: Minimum version confidence (a percentage, 0-100)
            needed to report the predicted version. Below it, a placeholder
            message is returned instead. Defaults to 25, the value that was
            previously hard-coded.

    Returns:
        A dict with ``marca``, ``modelo``, ``confianza_modelo`` and
        ``version``. ``confianza_version`` is included only when at least
        one version of the predicted model exists.

    Raises:
        Whatever PIL raises when ``image_bytes`` cannot be decoded (e.g.
        ``PIL.UnidentifiedImageError``); it is not caught here.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(DEVICE)

    # Step 1: predict the model. Labels are expected to start with the brand.
    top_model = predict_top(model_embeddings, model_labels, img_tensor, topk=1)[0]
    modelo_predecido = top_model["label"]
    confianza_modelo = top_model["confidence"]

    # Brand = first word, model = the rest. A one-word label leaves the model empty.
    marca, _, modelo = modelo_predecido.partition(" ")

    # Step 2: keep only the versions of the predicted model. A bare
    # startswith() would also match longer model names sharing the same
    # prefix (e.g. model "X C" picking up "X CLA ..." versions), so a space
    # must follow the model name (or the label must be exactly the model).
    prefijo = modelo_predecido + " "
    candidatos = [
        (idx, label)
        for idx, label in enumerate(version_labels)
        if label == modelo_predecido or label.startswith(prefijo)
    ]
    if not candidatos:
        return {
            "marca": marca,
            "modelo": modelo,
            "confianza_modelo": confianza_modelo,
            "version": "No se encontraron versiones para este modelo",
        }

    # Step 3: predict the version among the candidate versions only.
    indices = [idx for idx, _ in candidatos]
    etiquetas = [label for _, label in candidatos]
    top_version = predict_top(version_embeddings[indices], etiquetas, img_tensor, topk=1)[0]

    etiqueta = top_version["label"]
    if top_version["confidence"] >= umbral_version:
        # Strip only the leading model name. str.replace() would also remove
        # any later repetition of it inside the version text.
        version_predicha = etiqueta[len(prefijo):] if etiqueta.startswith(prefijo) else etiqueta
    else:
        version_predicha = "Versión no identificada con suficiente confianza"

    return {
        "marca": marca,
        "modelo": modelo,
        "confianza_modelo": confianza_modelo,
        "version": version_predicha,
        "confianza_version": top_version["confidence"],
    }
@app.post("/predict/")
async def predict(front: UploadFile = File(...), back: Optional[UploadFile] = File(None)):
    """Classify the vehicle in the uploaded front image.

    ``back`` is accepted and read but not used yet. The response is the dict
    produced by ``process_image`` as JSON, or HTTP 400 when the front upload
    cannot be decoded as an image.
    """
    front_bytes = await front.read()
    if back is not None:
        await back.read()  # not used for now

    try:
        # process_image is blocking CPU/GPU work. Running it in the threadpool
        # keeps the event loop free to serve other requests meanwhile.
        result = await run_in_threadpool(process_image, front_bytes)
    except (OSError, Image.DecompressionBombError):
        # PIL decoding failures (UnidentifiedImageError is an OSError
        # subclass, as are truncated-image errors) are client errors, not 500s.
        return JSONResponse(
            status_code=400,
            content={"error": "No se pudo leer la imagen enviada"},
        )
    return JSONResponse(content=result)
|