import os
from glob import glob
from typing import Optional

import gradio as gr
import torch
from torchvision.transforms.functional import resize, to_pil_image
from transformers import AutoModel, CLIPProcessor

PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"

MARKDOWN_DESCRIPTION = """
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
    <h1>Vocabulary-free Image Classification</h1>
</div>
<div style="display: flex;
            flex-wrap: wrap;
            align-items: center;
            justify-content: center;
            margin-bottom: 1rem;">
    <a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem; margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
    </a>
    <a href="https://huggingface.co/spaces/altndrr/vic" style="margin-right: 0.5rem;
       margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/demo-hf.altndrr%2Fvic-yellow.svg"/>
    </a>
    <a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;
       margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/paper-arXiv.2306.00917-B31B1B.svg"/>
    </a>
    <a href="https://alessandroconti.me/papers/2306.00917" style="margin-right: 0.5rem;
       margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
    </a>
</div>
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


def image_preprocess(image: gr.Image):
    """Preprocess the input with the CLIP processor and return a previewable copy."""
    if image is None:
        return None, None

    # Resize the shortest edge to the size expected by the processor.
    size = PROCESSOR.image_processor.size["shortest_edge"]
    size = min(size) if isinstance(size, tuple) else size
    image = resize(image, size)

    # Temporarily disable normalization so the preview remains a valid image.
    PROCESSOR.image_processor.do_normalize = False
    image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    PROCESSOR.image_processor.do_normalize = True
    image_tensor = image_tensor.pixel_values[0]
    curr_image = to_pil_image(image_tensor)

    return curr_image, image.copy()


def image_inference(image: gr.Image, alpha: Optional[float] = None):
    """Run CaSED on the image and return a class-name-to-score mapping."""
    if image is None:
        return None

    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    images = images.to(DEVICE)  # keep the inputs on the same device as the model
    with torch.no_grad():
        outputs = MODEL(images, alpha=alpha)

    # The model returns a generated vocabulary and one score per candidate name.
    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0].tolist()
    confidences = dict(zip(vocabulary, scores))

    return confidences
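

# Gradio demo: input image and alpha slider on the left, predicted vocabulary
# with scores on the right.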
with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
    # LAYOUT
    gr.Markdown(MARKDOWN_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            curr_image = gr.Image(
                label="input", type="pil", sources=["upload", "webcam", "clipboard"]
            )
            # alpha weighs the image- and text-based scores when ranking candidates.
            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
            with gr.Row():
                clear_button = gr.Button(value="Clear", variant="secondary")
                run_button = gr.Button(value="Submit", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="output", num_top_classes=5)

    # Hidden images: one keeps the unprocessed input, one feeds cached examples.
    _orig_image = gr.Image(label="original image", type="pil", visible=False, interactive=False)
    _example_image = gr.Image(label="example image", type="pil", visible=False, interactive=False)

    examples = gr.Examples(
        examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
        inputs=[_example_image],
        outputs=[output_label],
        fn=image_inference,
        cache_examples=True,
    )
    gr.Markdown(f"Check out the <a href='{PAPER_URL}'>original paper</a> for more information.")

    # INTERACTIONS
    # - change
    _example_image.change(image_preprocess, [_example_image], [curr_image, _orig_image])
    # - upload
    curr_image.upload(image_preprocess, [curr_image], [curr_image, _orig_image])
    curr_image.upload(lambda: None, [], [output_label])
    # - clear
    curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])
    # - click
    clear_button.click(lambda: (None, None, None), [], [curr_image, _orig_image, output_label])
    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])
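

# Serve requests through a queue so long-running inference calls do not block.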
if __name__ == "__main__":
    demo.queue()
    demo.launch()