File size: 4,281 Bytes
1237317 abac148 70c6b79 1237317 70c6b79 1237317 70c6b79 abac148 70c6b79 abac148 70c6b79 abac148 70c6b79 abac148 70c6b79 abac148 70c6b79 abac148 c3521e3 abac148 1237317 abac148 e50ddcb abac148 c3521e3 abac148 c3521e3 abac148 70c6b79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
from PIL import Image
import torch
import matplotlib.pyplot as plt
import cv2
import numpy as np
import spaces
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Hugging Face Hub checkpoint shared by the CLIPSeg processor and model.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Pre-/post-processing (tokenizer + image transforms) for CLIPSeg.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)

# Segmentation model, loaded once at import time and moved to the GPU.
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
model = model.cuda()
def _normalize_mask(mask):
    """Min-max rescale *mask* to float64 values in [0, 1].

    A constant mask (max == min) has no range to rescale; dividing by zero
    would produce NaN everywhere, so it is mapped to all zeros instead.
    """
    mask = mask.astype(np.float64)
    lo = mask.min()
    span = mask.max() - lo
    if span == 0:
        return np.zeros_like(mask)
    return (mask - lo) / span


def _predict_mask(image, prompt):
    """Run CLIPSeg on *image* for the text *prompt*.

    Returns a 2-D heatmap with the same width/height as ``image.size``,
    min-max normalized to [0, 1].
    """
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    heat = torch.sigmoid(logits).squeeze().cpu().numpy()
    # Quantize to an 8-bit grayscale image so PIL can resize it to the
    # input resolution (resizing "L" directly is equivalent to resizing an
    # RGB copy and keeping one channel).
    heat_img = Image.fromarray(np.uint8(heat * 255), "L").resize(image.size)
    return _normalize_mask(np.array(heat_img))


def _render_overlay(image, mask, bmask, alpha_value, draw_rectangles):
    """Build a matplotlib Figure showing *mask* as a heatmap over *image*.

    When *draw_rectangles* is true, a yellow bounding box is drawn around
    each outer contour of the boolean mask *bmask*.
    """
    # Build the Figure directly instead of via pyplot: pyplot keeps every
    # figure in a global registry until closed, which leaks memory across
    # requests in a long-running server.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")
    if draw_rectangles:
        # RETR_EXTERNAL: only outer contours, so holes inside a region do
        # not get their own rectangle. [-2] picks the contour list on both
        # OpenCV 3 (3-tuple) and OpenCV 4 (2-tuple) return signatures.
        contours = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            ax.add_patch(
                plt.Rectangle(
                    (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                )
            )
    ax.axis("off")
    fig.tight_layout()
    return fig


@spaces.GPU
def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of *image* described by *prompt* with CLIPSeg.

    Args:
        image: PIL image to segment.
        prompt: Text describing what to find in the image.
        threshold: Cut-off in [0, 1] applied to the normalized heatmap.
        alpha_value: Opacity of the heatmap drawn over the image.
        draw_rectangles: Whether to draw bounding boxes around the regions
            that pass the threshold.

    Returns:
        A tuple ``(fig, mask, output_image)``:
        - a matplotlib Figure with the heatmap overlaid on the image,
        - the normalized heatmap with values not above *threshold* set to 0,
        - an RGBA image containing only the thresholded pixels of *image*,
          transparent elsewhere.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    mask = _predict_mask(image, prompt)

    # Boolean foreground mask; the heatmap is zeroed with the same
    # comparison so both outputs agree on which pixels are foreground.
    bmask = mask > threshold
    mask = np.where(bmask, mask, 0.0)

    fig = _render_overlay(image, mask, bmask, alpha_value, draw_rectangles)

    # Cut the foreground out of the original image onto a transparent canvas.
    cutout_mask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=cutout_mask)
    return fig, mask, output_image
# Title string for the demo page.
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
# Usage instructions shown to the user; the action button is labelled "Process".
description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'Process'. Results will show up in a few seconds."
# HTML footer with links to the paper and the transformers documentation.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
# Gradio UI: inputs on the left column, results on the right.
with gr.Blocks(title=title) as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)
    gr.Markdown(
        "*Example images are taken from the [ImageNet-A](https://paperswithcode.com/dataset/imagenet-a) dataset*"
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            input_prompt = gr.Textbox(label="Please describe what you want to identify")
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            input_slider_A = gr.Slider(minimum=0, maximum=1, value=0.5, label="Alpha")
            draw_rectangles = gr.Checkbox(label="Draw rectangles")
            btn_process = gr.Button("Process")
        with gr.Column():
            output_plot = gr.Plot(label="Segmentation Result")
            output_mask = gr.Image(label="Mask")
            output_image = gr.Image(label="Output Image")

    # Input order must match process_image's parameter order.
    process_inputs = [
        input_image,
        input_prompt,
        input_slider_T,
        input_slider_A,
        draw_rectangles,
    ]
    btn_process.click(
        process_image,
        inputs=process_inputs,
        outputs=[output_plot, output_mask, output_image],
    )
    # Clicking an example only fills in the inputs; the user then clicks Process.
    gr.Examples(
        [
            ["0.003473_cliff _ cliff_0.51112.jpg", "dog", 0.5, 0.5, True],
            ["0.001861_submarine _ submarine_0.9862991.jpg", "beacon", 0.55, 0.4, True],
            ["0.004658_spatula _ spatula_0.35416836.jpg", "banana", 0.4, 0.5, True],
        ],
        inputs=process_inputs,
    )

# Launch only when run as a script, so importing this module (e.g. by
# tooling that looks up `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio tunnel when run
    # locally; per Gradio docs it is not supported on Hugging Face Spaces.
    demo.launch(share=True)