Spaces:
Running
Running
import os | |
from glob import glob | |
from typing import Optional | |
import gradio as gr | |
import torch | |
from torchvision.transforms.functional import resize, to_pil_image | |
from transformers import AutoModel, CLIPProcessor | |
# Paper title; used as the browser-tab title of the Blocks app and in the page header.
PAPER_TITLE = "Vocabulary-free Image Classification"
# arXiv page of the paper; linked from the header badge and from the closing note.
PAPER_URL = "https://arxiv.org/abs/2306.00917"
# HTML header rendered through gr.Markdown: the title plus badge links to the code,
# the hosted demo, the paper and the project website. The title and the paper URL are
# interpolated from the constants above so the header cannot drift from them; the
# badge image URLs are static shields.io links. (The text contains no braces, so the
# f-string only substitutes the two names.)
MARKDOWN_DESCRIPTION = f"""
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
<h1>{PAPER_TITLE}</h1>
</div>
<div style="display: flex;
flex-wrap: wrap;
align-items: center;
justify-content: center;
margin-bottom: 1rem;">
<a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem; margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
</a>
<a href="https://huggingface.co/spaces/altndrr/vic" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/demo-hf.altndrr%2Fvic-yellow.svg"/>
</a>
<a href="{PAPER_URL}" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/paper-arXiv.2306.00917-B31B1B.svg"/>
</a>
<a href="https://alessandroconti.me/papers/2306.00917" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
</a>
</div>
"""
# Prefer a CUDA device when torch can see one; otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# CaSED model from the Hugging Face Hub, moved to DEVICE once at import time.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# "altndrr/cased" repository; consider pinning a `revision=` for reproducibility.
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True)
MODEL = MODEL.to(DEVICE)

# CLIP ViT-L/14 processor shared by image_preprocess and image_inference.
# Presumably it matches the preprocessing the CaSED vision encoder expects — confirm.
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
def image_preprocess(image: gr.Image):
    """Prepare an input image for display in the UI.

    The image is resized so its shortest edge equals the CLIP image processor's
    ``shortest_edge`` size. It is then run through the image processor with
    normalization disabled, so the resulting tensor can be converted back into
    a viewable picture. (With the default CLIP config this also applies the
    processor's center crop — confirm against the loaded config.)

    Args:
        image: the value delivered by a ``type="pil"`` Gradio image component
            (a PIL image at runtime despite the ``gr.Image`` annotation), or
            ``None`` when the component is empty.

    Returns:
        ``(preview, resized)``: ``preview`` is the processed, unnormalized image
        converted back to PIL; ``resized`` is a copy of the resized input.
        ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    size = PROCESSOR.image_processor.size["shortest_edge"]
    size = min(size) if isinstance(size, tuple) else size
    image = resize(image, size)

    # Disable normalization for this call only. Toggling the shared
    # `PROCESSOR.image_processor.do_normalize` attribute is not exception-safe:
    # a failure would leave it off for every later inference. It would also
    # leak into an `image_inference` call running concurrently on another worker.
    pixel_values = PROCESSOR.image_processor(
        images=[image], return_tensors="pt", do_normalize=False
    ).pixel_values
    preview = to_pil_image(pixel_values[0])
    return preview, image.copy()
def image_inference(image: gr.Image, alpha: Optional[float] = None):
    """Classify ``image`` with the CaSED model and return label confidences.

    Args:
        image: the value delivered by a ``type="pil"`` Gradio image component
            (a PIL image at runtime), or ``None`` when nothing is loaded.
        alpha: forwarded unchanged to ``MODEL`` as its ``alpha`` argument; its
            meaning (and the effect of ``None``) is defined by the remote model
            code — see the CaSED model card.

    Returns:
        A ``{class_name: score}`` dict built from the first (only) image of the
        batch, in the shape ``gr.Label`` accepts, or ``None`` when ``image`` is
        ``None``.
    """
    if image is None:
        return None

    # MODEL lives on DEVICE, so put the processed inputs there as well. Without
    # this, a CUDA host only works if the remote model code moves the tensors itself.
    inputs = PROCESSOR(images=[image], return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        outputs = MODEL(inputs, alpha=alpha)

    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0].tolist()
    # If the vocabulary contains duplicate names, later scores overwrite earlier ones.
    return dict(zip(vocabulary, scores))
# Gradio UI: input image + alpha slider + buttons on the left, top-5 labels on the right.
# NOTE(review): layout nesting was reconstructed from a copy that lost its
# indentation — confirm against the deployed layout.
with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
    # LAYOUT
    gr.Markdown(MARKDOWN_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            curr_image = gr.Image(
                label="input", type="pil", sources=["upload", "webcam", "clipboard"]
            )
            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
            with gr.Row():
                clear_button = gr.Button(value="Clear", variant="secondary")
                run_button = gr.Button(value="Submit", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="output", num_top_classes=5)

    # Hidden state holders. _orig_image keeps the resized original returned by
    # image_preprocess. _example_image receives a clicked example and forwards it
    # to image_preprocess through its .change listener below.
    _orig_image = gr.Image(label="original image", type="pil", visible=False, interactive=False)
    _example_image = gr.Image(label="example image", type="pil", visible=False, interactive=False)

    examples = gr.Examples(
        # glob() returns files in arbitrary, filesystem-dependent order; sort them
        # so the example gallery (and its cache) is stable across machines.
        examples=sorted(glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg"))),
        inputs=[_example_image],
        outputs=[output_label],
        fn=image_inference,
        cache_examples=True,
    )
    gr.Markdown(f'Check out the <a href="{PAPER_URL}">original paper</a> for more information.')

    # INTERACTIONS
    # - change: a clicked example is preprocessed into the visible input.
    _example_image.change(image_preprocess, [_example_image], [curr_image, _orig_image])
    # - upload: show the preprocessed preview and drop any stale prediction.
    curr_image.upload(image_preprocess, [curr_image], [curr_image, _orig_image])
    curr_image.upload(lambda: None, [], [output_label])
    # - clear: the image component's own clear resets the hidden original and the output.
    curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])
    # - click: "Clear" resets everything; "Submit" classifies the displayed image.
    clear_button.click(lambda: (None, None, None), [], [curr_image, _orig_image, output_label])
    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])
if __name__ == "__main__":
    # Blocks.queue() returns the same Blocks object, so enabling the request
    # queue and starting the server can be chained.
    demo.queue().launch()