import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from torchvision import transforms
import open_clip
from PIL import Image
import io
from typing import Optional

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# === 1) Load the CLIP model (ViT-B/16) ===
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-16", pretrained="openai"
)
clip_model = clip_model.to(DEVICE)
clip_model.eval()
for p in clip_model.parameters():
    p.requires_grad = False

# === 2) Load the text embeddings built with B/16 ===
# (Make sure these files exist: they were generated as text_embeddings_modelos_b16.pt and text_embeddings_b16.pt.)
model_ckpt = torch.load("text_embeddings_modelos_b16.pt", map_location=DEVICE)
model_labels = model_ckpt["labels"]
model_embeddings = model_ckpt["embeddings"].to(DEVICE)
model_embeddings /= model_embeddings.norm(dim=-1, keepdim=True)

version_ckpt = torch.load("text_embeddings_b16.pt", map_location=DEVICE)
version_labels = version_ckpt["labels"]
version_embeddings = version_ckpt["embeddings"].to(DEVICE)
version_embeddings /= version_embeddings.norm(dim=-1, keepdim=True)
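
# The two checkpoint files above are assumed to be dicts of the form
# {"labels": List[str], "embeddings": torch.Tensor}, produced offline.
# A minimal, hypothetical sketch of how they could have been built with the
# same B/16 model (not part of the original file; the label list and output
# path are placeholders):
def _build_text_embeddings(labels, out_path):
    tokenizer = open_clip.get_tokenizer("ViT-B-16")
    with torch.no_grad():
        tokens = tokenizer(labels).to(DEVICE)
        embeddings = clip_model.encode_text(tokens)
    # Normalization happens at load time above, so raw embeddings are saved.
    torch.save({"labels": labels, "embeddings": embeddings.cpu()}, out_path)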

# Image transform (reuses the Normalize step from the B/16 preprocess pipeline)
normalize = next(t for t in preprocess.transforms if isinstance(t, transforms.Normalize))
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=normalize.mean, std=normalize.std),
])
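
# Note: the `preprocess` returned by open_clip (Resize + CenterCrop + ToTensor
# + Normalize) could be used here directly; the rebuilt transform above
# stretches images to exactly 224x224 instead of center-cropping them.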

app = FastAPI()

def predict_top(text_feats, text_labels, image_tensor, topk=3):
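    # Encode the image with CLIP, compare it against the given text embeddings
    # (both sides L2-normalized, so the dot product is cosine similarity),
    # scale by CLIP's usual logit factor of 100, and softmax over the
    # candidate labels before taking the top-k.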
    with torch.no_grad():
        image_features = clip_model.encode_image(image_tensor)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_feats.T).softmax(dim=-1)
        topk_result = torch.topk(similarity[0], k=topk)
    return [
        {"label": text_labels[idx], "confidence": round(conf.item() * 100, 2)}
        for conf, idx in zip(topk_result.values, topk_result.indices)
    ]

def process_image(image_bytes: bytes):
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(DEVICE)

    # Step 1: predict the car model
    top_model = predict_top(model_embeddings, model_labels, img_tensor, topk=1)[0]
    modelo_predecido = top_model["label"]
    confianza_modelo = top_model["confidence"]

    # Split brand and model carefully (in case the label is a single word)
    partes = modelo_predecido.split(" ", 1)
    marca = partes[0]
    modelo = partes[1] if len(partes) > 1 else ""

    # Step 2: keep only versions whose label starts with the full model label
    versiones_filtradas = [
        (label, idx) for idx, label in enumerate(version_labels)
        if label.startswith(modelo_predecido)
    ]

    if not versiones_filtradas:
        return {
            "marca": marca,
            "modelo": modelo,
            "confianza_modelo": confianza_modelo,
            "version": "No se encontraron versiones para este modelo"
        }

    # Step 3: predict the version among this model's versions
    indices_versiones = [idx for _, idx in versiones_filtradas]
    versiones_labels = [label for label, _ in versiones_filtradas]
    versiones_embeds = version_embeddings[indices_versiones]

    top_version = predict_top(versiones_embeds, versiones_labels, img_tensor, topk=1)[0]
    version_predicha = (
        top_version["label"].replace(modelo_predecido + " ", "")
        if top_version["confidence"] >= 25
        else "Versi贸n no identificada con suficiente confianza"
    )

    return {
        "marca": marca,
        "modelo": modelo,
        "confianza_modelo": confianza_modelo,
        "version": version_predicha,
        "confianza_version": top_version["confidence"]
    }

@app.post("/predict/")
async def predict(front: UploadFile = File(...), back: Optional[UploadFile] = File(None)):
    front_bytes = await front.read()
    if back:
        _ = await back.read()  # not used for now
    result = process_image(front_bytes)
    return JSONResponse(content=result)
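
# A minimal way to run this service locally (assumes uvicorn is installed,
# which the original file does not show):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request (hypothetical file name):
#   curl -X POST http://localhost:8000/predict/ -F "front=@car_front.jpg"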