import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
import os
import numpy as np
from datetime import datetime
# Import utils (and save_img) if they are not already imported
import utils
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Paths to the static image and GIF
README_IMAGE_PATH = os.path.join("figs", "saliencies-merit-dataset.png")
GIF_PATH = os.path.join("figs", "demo-samples.gif")

# Global variables for Donut model, processor, and dataset
dataset = None


def load_merit_dataset():
    global dataset
    if dataset is None:
        dataset = load_dataset(
            "de-Rodrigo/merit", name="en-digital-seq", split="test", num_proc=8
        )
    return dataset
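

# Note: each record in the split loaded above exposes an "image" field, which
# get_image_from_dataset below indexes into by integer position.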


def get_image_from_dataset(index):
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()
    image_data = dataset[int(index)]["image"]
    return image_data


def get_collection_models(tag: str) -> List[str]:
    """Get a list of models from a specific Hugging Face collection."""
    models = list_models(author="de-Rodrigo")
    return [model.modelId for model in models if tag in model.tags]


def initialize_donut():
    try:
        donut_model = VisionEncoderDecoderModel.from_pretrained(
            "de-Rodrigo/donut-merit"
        )
        donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
        donut_model = donut_model.to("cuda")
        logger.info("Donut model loaded successfully on GPU")
        return donut_model, donut_processor
    except Exception as e:
        logger.error(f"Error loading Donut model: {str(e)}")
        raise
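

# initialize_donut assumes a CUDA device, as provided on this Space. A hedged,
# device-agnostic variant for local runs might look like this (illustrative
# sketch only, not used by the app):
def _initialize_donut_any_device():
    """Illustrative sketch: load the same checkpoint on whatever device exists."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VisionEncoderDecoderModel.from_pretrained("de-Rodrigo/donut-merit")
    processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
    return model.to(device), processor, device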


def compute_saliency(outputs, pixels, donut_p, image):
    token_logits = torch.stack(outputs.scores, dim=1)
    token_probs = torch.softmax(token_logits, dim=-1)

    token_texts = []
    saliency_images = []

    for token_index in range(len(token_probs[0])):
        target_token_prob = token_probs[
            0, token_index, outputs.sequences[0, token_index]
        ]

        # Reset gradients from the previous token before backpropagating again
        if pixels.grad is not None:
            pixels.grad.zero_()
        target_token_prob.backward(retain_graph=True)

        # Channel-averaged absolute gradient of the input pixels
        saliency = pixels.grad.data.abs().squeeze().mean(dim=0)

        token_id = outputs.sequences[0][token_index].item()
        token_text = donut_p.tokenizer.decode([token_id])
        logger.info(f"Considered sequence token: {token_text}")

        # Build a filesystem-safe, timestamped name (kept for optional saving)
        safe_token_text = re.sub(r'[<>:"/\\|?*]', "_", token_text)
        current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
        unique_safe_token_text = f"{safe_token_text}_{current_datetime}"
        file_name = f"saliency_{unique_safe_token_text}.png"

        # Overlay the saliency map on the original image twice for a stronger blend
        saliency = utils.convert_tensor_to_rgba_image(saliency)
        saliency = utils.add_transparent_image(np.array(image), saliency)
        saliency = utils.convert_rgb_to_rgba_image(saliency)
        saliency = utils.add_transparent_image(np.array(image), saliency, 0.7)
        saliency = utils.label_frame(saliency, token_text)

        saliency_images.append(saliency)
        token_texts.append(token_text)

    return saliency_images, token_texts
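

# compute_saliency above is vanilla gradient saliency: each generated token's
# probability is backpropagated to the input pixels, and the channel-averaged
# absolute gradient marks the regions that token depends on. A minimal,
# self-contained sketch of the same idea (names here are illustrative, not part
# of the app):
def _gradient_saliency_sketch(forward_fn, pixels, class_index):
    """Illustrative sketch: |d p(class) / d pixels|, averaged over channels."""
    pixels = pixels.clone().detach().requires_grad_(True)
    logits = forward_fn(pixels)  # any differentiable forward pass
    prob = torch.softmax(logits, dim=-1)[0, class_index]
    prob.backward()  # populates pixels.grad
    return pixels.grad.abs().squeeze().mean(dim=0)  # one heat map per image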


# On ZeroGPU Spaces, the `spaces` import above expects GPU work to run inside a
# function decorated with @spaces.GPU; added here since the model runs on CUDA.
@spaces.GPU
def process_image_donut(image):
    try:
        model, processor = initialize_donut()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
        pixel_values.requires_grad = True

        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to("cuda")

        # Call generate via __wrapped__ to bypass its no_grad decorator, so
        # gradients can flow back to pixel_values for the saliency maps
        outputs = model.generate.__wrapped__(
            model,
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
        )

        saliency_images, token_texts = compute_saliency(
            outputs, pixel_values, processor, image
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Strip the leading task-prompt token from the decoded sequence
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return saliency_images, json.dumps(result, indent=2)

    except Exception as e:
        logger.error(f"Error processing image with Donut: {str(e)}")
        return None, f"Error: {str(e)}"


def process_image(model_name, image=None, dataset_image_index=None):
    # Prefer an uploaded image; fall back to the dataset index otherwise
    # (the slider always carries a value, so checking it first would mask uploads)
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        saliency_images, result = process_image_donut(image)
    else:
        # TODO: implement processing for other models here
        saliency_images, result = None, f"Processing for model {model_name} not implemented"
    return saliency_images, result


def update_image(dataset_image_index):
    return get_image_from_dataset(dataset_image_index)


if __name__ == "__main__":
    # Load the dataset
    load_merit_dataset()

    models = get_collection_models("saliency")
    models.append("de-Rodrigo/donut-merit")

    with gr.Blocks() as demo:
        gr.Markdown("# Saliency Maps with the MERIT Dataset 🎒📃🏆")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Image(value=README_IMAGE_PATH, height=400)
            with gr.Column(scale=1):
                gr.Image(
                    value=GIF_PATH, label="Dataset samples you can process", height=400
                )
| with gr.Tab("Introduction"): | |
| gr.Markdown( | |
| """ | |
| ## Welcome to Saliency Maps with the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit) 🎒📃🏆 | |
| This space demonstrates the capabilities of different Vision Language models | |
| for document understanding tasks. | |
| ### Key Features: | |
| - Process images from the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit) or upload your own image. | |
| - Use a fine-tuned version of the models availabe to extract grades from documents. | |
| - Visualize saliency maps to understand where the model is looking (WIP 🛠️). | |
| """ | |
| ) | |
| with gr.Tab("Try It Yourself"): | |
| gr.Markdown( | |
| "Select a model and an image from the dataset, or upload your own image." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| model_dropdown = gr.Dropdown(choices=models, label="Select Model") | |
| dataset_slider = gr.Slider( | |
| minimum=0, | |
| maximum=len(dataset) - 1, | |
| step=1, | |
| label="Dataset Image Index", | |
| ) | |
| upload_image = gr.Image( | |
| type="pil", label="Or Upload Your Own Image" | |
| ) | |
| preview_image = gr.Image(label="Selected/Uploaded Image") | |
| process_button = gr.Button("Process Image") | |
| with gr.Row(): | |
| output_image = gr.Gallery(label="Processed Saliency Images") | |
| output_text = gr.Textbox(label="Result") | |
| # Update preview image when slider changes | |
| dataset_slider.change( | |
| fn=update_image, inputs=[dataset_slider], outputs=[preview_image] | |
| ) | |
| # Update preview image when an image is uploaded | |
| upload_image.change( | |
| fn=lambda x: x, inputs=[upload_image], outputs=[preview_image] | |
| ) | |
| # Process image when button is clicked | |
| process_button.click( | |
| fn=process_image, | |
| inputs=[model_dropdown, upload_image, dataset_slider], | |
| outputs=[output_image, output_text], | |
| ) | |
| demo.launch() | |