coder
primer_commit
9cc66e2
raw
history blame
1.17 kB
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
from io import BytesIO
import requests
class Generador():
def __init__(self, configuraciones):
self.modelo = configuraciones.get('model')
self.tokenizer = configuraciones.get('tokenizer')
def generar_prediccion(self, imagen_bytes):
# @title **Ejemplo práctico**
prediccion = None
try:
# Inicializamos los procesadores y el modelo
procesador = ViTImageProcessor.from_pretrained(self.tokenizer)
modelo = ViTForImageClassification.from_pretrained(self.modelo)
# Procesamos nuestra imagen
inputs = procesador(images=imagen_bytes, return_tensors="pt")
outputs = modelo(**inputs)
logits = outputs.logits
# Obtenemos las predicciones
predicted_class_idx = logits.argmax(-1).item()
prediccion = modelo.config.id2label[predicted_class_idx]
except Exception as error:
print(f"No es Chems\n{error}")
prediccion = error
finally:
self.prediccion = str(prediccion)