import gradio as gr
import torch
from transformers import (
    Idefics2Processor, Idefics2ForConditionalGeneration,
    Blip2Processor, Blip2ForConditionalGeneration
)
from PIL import Image
import time
import pandas as pd
import nltk
from nltk.translate.bleu_score import sentence_bleu
# Download 'punkt' if it is not already available
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Model definitions
models = {
    "IDEFICS2": {
        "model_id": "HuggingFaceM4/idefics2-8b",
        "processor_class": Idefics2Processor,
        "model_class": Idefics2ForConditionalGeneration,
        "caption_prompt": "<image>Describe the image in detail"
    },
    "BLIP2": {
        "model_id": "Salesforce/blip2-opt-2.7b",
        "processor_class": Blip2Processor,
        "model_class": Blip2ForConditionalGeneration,
        "caption_prompt": ""  # Empty prompt for BLIP2 (unconditional captioning)
    }
}
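# The registry above is extensible: benchmarking another checkpoint only needs a
# new entry. A sketch (hypothetical entry, not loaded by this app):
#
#     "BLIP2-FlanT5": {
#         "model_id": "Salesforce/blip2-flan-t5-xl",
#         "processor_class": Blip2Processor,
#         "model_class": Blip2ForConditionalGeneration,
#         "caption_prompt": ""
#     }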
# Load models up front (pre-loaded to avoid per-request delays)
model_instances = {}
for model_name, config in models.items():
    processor = config["processor_class"].from_pretrained(config["model_id"])
    model = config["model_class"].from_pretrained(config["model_id"]).to(device)
    model_instances[model_name] = (processor, model)
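# Both checkpoints load in full fp32 above, which is memory-hungry (IDEFICS2-8B
# alone needs roughly 32 GB of weights). A minimal lower-memory sketch, assuming
# a CUDA GPU and the `accelerate` package (an extra dependency, not used above):
#
#     model = config["model_class"].from_pretrained(
#         config["model_id"],
#         torch_dtype=torch.float16,  # halves the weight memory footprint
#         device_map="auto"           # lets accelerate place the weights
#     )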
# Predefined VQA questions (offered as suggestions in the UI)
vqa_questions = [
    "Are there people in the image?",
    "Which color predominates in the image?"
]
# Generic reference for BLEU (adjust to your own ground truth as needed)
reference_caption = ["An image with people and various objects"]
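# With a single generic reference, unsmoothed sentence_bleu is often near 0 for
# short captions (higher-order n-grams rarely match). An optional smoothed
# variant (a sketch using NLTK's SmoothingFunction, not part of this app):
#
#     from nltk.translate.bleu_score import SmoothingFunction
#     bleu = sentence_bleu(
#         [reference_caption[0].split()], caption.split(),
#         smoothing_function=SmoothingFunction().method1
#     )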
def infer(image, model_name, task, question=None):
    if image is None:
        return "Please upload an image.", None, None, None, None, None
    # Open and prepare the image
    image = Image.open(image).convert("RGB")
    if "BLIP2" in model_name:
        image = image.resize((224, 224))
    processor, model = model_instances[model_name]
    start_time = time.time()
    # Note: measured before generation, so this mostly reflects the loaded weights
    vram = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0
if task == "captioning":
caption_prompt = models[model_name]["caption_prompt"]
caption_text = "" if "BLIP2" in model_name else caption_prompt
inputs = processor(images=image, text=caption_text, return_tensors="pt").to(device)
output_ids = model.generate(
**inputs,
max_new_tokens=50,
num_beams=5 if "BLIP2" in model_name else 1,
no_repeat_ngram_size=2 if "BLIP2" in model_name else 0
)
caption = processor.decode(output_ids[0], skip_special_tokens=True)
inference_time = time.time() - start_time
# Calcular BLEU (simplificado, usando referencia genérica)
bleu_score = sentence_bleu([reference_caption[0].split()], caption.split()) if caption else 0.0
return (caption, inference_time, None, None, vram, bleu_score)
elif task == "vqa" and question:
vqa_text = question if "BLIP2" in model_name else f"<image>Q: {question}"
inputs = processor(images=image, text=vqa_text, return_tensors="pt").to(device)
output_ids = model.generate(
**inputs,
max_new_tokens=10,
num_beams=5 if "BLIP2" in model_name else 1,
no_repeat_ngram_size=2 if "BLIP2" in model_name else 0
)
vqa_answer = processor.decode(output_ids[0], skip_special_tokens=True)
inference_time = time.time() - start_time
return (None, None, vqa_answer, inference_time, vram, None)
return "Selecciona una tarea válida y, para VQA, una pregunta.", None, None, None, None, None
# Gradio interface
with gr.Blocks(title="MLLM Benchmark Demo") as demo:
    gr.Markdown("# Benchmark for Multimodal LLMs (MLLMs)")
    gr.Markdown("Upload an image, select a model and a task, and get captioning or VQA results.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Image")
            model_dropdown = gr.Dropdown(choices=["IDEFICS2", "BLIP2"], label="Select Model", value="IDEFICS2")
            task_dropdown = gr.Dropdown(choices=["captioning", "vqa"], label="Select Task", value="captioning")
            question_input = gr.Textbox(label="VQA Question (only used for VQA)", placeholder=f"E.g.: {vqa_questions[0]}")
            submit_btn = gr.Button("Generate")
        with gr.Column():
            # One output component per value returned by infer()
            caption_output = gr.Textbox(label="Generated Caption")
            caption_time_output = gr.Number(label="Captioning Time (s)")
            vqa_output = gr.Textbox(label="VQA Answer")
            vqa_time_output = gr.Number(label="VQA Time (s)")
            vram_output = gr.Number(label="VRAM (GB)")
            bleu_output = gr.Number(label="BLEU Score")
    submit_btn.click(
        fn=infer,
        inputs=[image_input, model_dropdown, task_dropdown, question_input],
        outputs=[caption_output, caption_time_output, vqa_output, vqa_time_output, vram_output, bleu_output]
    )
gr.Markdown("### Notas")
gr.Markdown("""
- para mejroar la velocidad de inferencia, descarga en local y usar GPU avanzada.
- La métrica BLEU usa una referencia genérica y puede no reflejar la calidad real.
- Para más detalles, consulta el [repositorio del paper](https://huggingface.co/spaces/Pdro-ruiz/MLLM_Estado_del_Arte_Feb25/tree/main).
""")
if __name__ == "__main__":
    demo.launch()