import os
from glob import glob
from typing import Optional
import gradio as gr
import torch
from torchvision.transforms.functional import resize, to_pil_image
from transformers import AutoModel, CLIPProcessor
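# Gradio demo for CaSED ("Vocabulary-free Image Classification"): instead of scoring a
# fixed list of labels, the model proposes its own candidate class names and ranks them.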
PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"
MARKDOWN_DESCRIPTION = f"""
# {PAPER_TITLE}

Classify an image without choosing a label set in advance: the model proposes its own
candidate class names and ranks them.
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
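# The CaSED checkpoint ships its modelling code with the weights, hence trust_remote_code=True.
# The CLIP processor only prepares pixel values; candidate class names come from the model itself.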
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
def image_preprocess(image: gr.Image):
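    """Resize and crop the image as the model will see it, returning a preview plus the original.

    Normalization is disabled for this pass only, so the returned preview is still a
    viewable PIL image; the second return value is an untouched copy of the input.
    """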
if image is None:
return None, None
size = PROCESSOR.image_processor.size["shortest_edge"]
size = min(size) if isinstance(size, tuple) else size
image = resize(image, size)
PROCESSOR.image_processor.do_normalize = False
image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
PROCESSOR.image_processor.do_normalize = True
image_tensor = image_tensor.pixel_values[0]
curr_image = to_pil_image(image_tensor)
return curr_image, image.copy()
def image_inference(image: gr.Image, alpha: Optional[float] = None):
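    """Run the model on one image and return a {candidate name: score} dict for gr.Label.

    `alpha` is passed through to the model; in CaSED it balances the visual and the
    textual matching scores used to rank the candidate names (see the paper).
    """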
if image is None:
return None
    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    images = images.to(DEVICE)  # keep the inputs on the same device as the model
    with torch.no_grad():
        outputs = MODEL(images, alpha=alpha)
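    # The model returns one candidate vocabulary and one score vector per input image.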
vocabulary = outputs["vocabularies"][0]
scores = outputs["scores"][0].tolist()
confidences = dict(zip(vocabulary, scores))
return confidences
with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
# LAYOUT
gr.Markdown(MARKDOWN_DESCRIPTION)
with gr.Row():
with gr.Column():
curr_image = gr.Image(
label="input", type="pil", sources=["upload", "webcam", "clipboard"]
)
alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
with gr.Row():
clear_button = gr.Button(value="Clear", variant="secondary")
run_button = gr.Button(value="Submit", variant="primary")
with gr.Column():
output_label = gr.Label(label="output", num_top_classes=5)
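    # Hidden helpers: _orig_image keeps the untouched upload, _example_image receives
    # clicks from the examples gallery below.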
_orig_image = gr.Image(label="original image", type="pil", visible=False, interactive=False)
_example_image = gr.Image(label="example image", type="pil", visible=False, interactive=False)
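    # Example images are routed through _example_image so they get the same preprocessing
    # as uploads; their predictions are cached (cache_examples=True).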
examples = gr.Examples(
examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
inputs=[_example_image],
outputs=[output_label],
fn=image_inference,
cache_examples=True,
)
gr.Markdown(f"Check out the original paper for more information.")
# INTERACTIONS
    # - change: picking an example fills _example_image, which is preprocessed like an upload
_example_image.change(image_preprocess, [_example_image], [curr_image, _orig_image])
    # - upload: preprocess the new image and clear any stale prediction
curr_image.upload(image_preprocess, [curr_image], [curr_image, _orig_image])
curr_image.upload(lambda: None, [], [output_label])
    # - clear: clearing the image also drops the stored original and the last prediction
curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])
    # - click: the buttons either reset everything or run inference with the current alpha
clear_button.click(lambda: (None, None, None), [], [curr_image, _orig_image, output_label])
run_button.click(image_inference, [curr_image, alpha_slider], [output_label])
if __name__ == "__main__":
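    # Queue requests so concurrent users take turns on the single loaded model.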
demo.queue()
demo.launch()