์ ๋ก์ท(zero-shot) ๊ฐ์ฒด ํ์ง[[zeroshot-object-detection]]
[[open-in-colab]]
์ผ๋ฐ์ ์ผ๋ก ๊ฐ์ฒด ํ์ง์ ์ฌ์ฉ๋๋ ๋ชจ๋ธ์ ํ์ตํ๊ธฐ ์ํด์๋ ๋ ์ด๋ธ์ด ์ง์ ๋ ์ด๋ฏธ์ง ๋ฐ์ดํฐ ์ธํธ๊ฐ ํ์ํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ํ์ต ๋ฐ์ดํฐ์ ์กด์ฌํ๋ ํด๋์ค(๋ ์ด๋ธ)๋ง ํ์งํ ์ ์๋ค๋ ํ๊ณ์ ์ด ์์ต๋๋ค.
๋ค๋ฅธ ๋ฐฉ์์ ์ฌ์ฉํ๋ OWL-ViT ๋ชจ๋ธ๋ก ์ ๋ก์ท ๊ฐ์ฒด ํ์ง๊ฐ ๊ฐ๋ฅํฉ๋๋ค. OWL-ViT๋ ๊ฐ๋ฐฉํ ์ดํ(open-vocabulary) ๊ฐ์ฒด ํ์ง๊ธฐ์ ๋๋ค. ์ฆ, ๋ ์ด๋ธ์ด ์ง์ ๋ ๋ฐ์ดํฐ ์ธํธ์ ๋ฏธ์ธ ์กฐ์ ํ์ง ์๊ณ ์์ ํ ์คํธ ์ฟผ๋ฆฌ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ด๋ฏธ์ง์์ ๊ฐ์ฒด๋ฅผ ํ์งํ ์ ์์ต๋๋ค.
OWL-ViT ๋ชจ๋ธ์ ๋ฉํฐ ๋ชจ๋ฌ ํํ์ ํ์ฉํด ๊ฐ๋ฐฉํ ์ดํ ํ์ง(open-vocabulary detection)๋ฅผ ์ํํฉ๋๋ค. CLIP ๋ชจ๋ธ์ ๊ฒฝ๋ํ(lightweight)๋ ๊ฐ์ฒด ๋ถ๋ฅ์ ์ง์ญํ(localization) ํค๋๋ฅผ ๊ฒฐํฉํฉ๋๋ค. ๊ฐ๋ฐฉํ ์ดํ ํ์ง๋ CLIP์ ํ ์คํธ ์ธ์ฝ๋๋ก free-text ์ฟผ๋ฆฌ๋ฅผ ์๋ฒ ๋ฉํ๊ณ , ๊ฐ์ฒด ๋ถ๋ฅ์ ์ง์ญํ ํค๋์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํฉ๋๋ค. ์ด๋ฏธ์ง์ ํด๋น ํ ์คํธ ์ค๋ช ์ ์ฐ๊ฒฐํ๋ฉด ViT๊ฐ ์ด๋ฏธ์ง ํจ์น(image patches)๋ฅผ ์ ๋ ฅ์ผ๋ก ์ฒ๋ฆฌํฉ๋๋ค. OWL-ViT ๋ชจ๋ธ์ ์ ์๋ค์ CLIP ๋ชจ๋ธ์ ์ฒ์๋ถํฐ ํ์ต(scratch learning)ํ ํ์, bipartite matching loss๋ฅผ ์ฌ์ฉํ์ฌ ํ์ค ๊ฐ์ฒด ์ธ์ ๋ฐ์ดํฐ์ ์ผ๋ก OWL-ViT ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ์ต๋๋ค.
์ด ์ ๊ทผ ๋ฐฉ์์ ์ฌ์ฉํ๋ฉด ๋ชจ๋ธ์ ๋ ์ด๋ธ์ด ์ง์ ๋ ๋ฐ์ดํฐ ์ธํธ์ ๋ํ ์ฌ์ ํ์ต ์์ด๋ ํ ์คํธ ์ค๋ช ์ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ์ฒด๋ฅผ ํ์งํ ์ ์์ต๋๋ค.
์ด๋ฒ ๊ฐ์ด๋์์๋ OWL-ViT ๋ชจ๋ธ์ ์ฌ์ฉ๋ฒ์ ๋ค๋ฃฐ ๊ฒ์ ๋๋ค:
- ํ ์คํธ ํ๋กฌํํธ ๊ธฐ๋ฐ ๊ฐ์ฒด ํ์ง
- ์ผ๊ด ๊ฐ์ฒด ํ์ง
- ์ด๋ฏธ์ง ๊ฐ์ด๋ ๊ฐ์ฒด ํ์ง
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๋ชจ๋ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
pip install -q transformers
์ ๋ก์ท(zero-shot) ๊ฐ์ฒด ํ์ง ํ์ดํ๋ผ์ธ[[zeroshot-object-detection-pipeline]]
[pipeline
]์ ํ์ฉํ๋ฉด ๊ฐ์ฅ ๊ฐ๋จํ๊ฒ OWL-ViT ๋ชจ๋ธ์ ์ถ๋ก ํด๋ณผ ์ ์์ต๋๋ค.
Hugging Face Hub์ ์
๋ก๋๋ ์ฒดํฌํฌ์ธํธ์์ ์ ๋ก์ท(zero-shot) ๊ฐ์ฒด ํ์ง์ฉ ํ์ดํ๋ผ์ธ์ ์ธ์คํด์คํํฉ๋๋ค:
>>> from transformers import pipeline
>>> checkpoint = "google/owlvit-base-patch32"
>>> detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
๋ค์์ผ๋ก, ๊ฐ์ฒด๋ฅผ ํ์งํ๊ณ ์ถ์ ์ด๋ฏธ์ง๋ฅผ ์ ํํ์ธ์. ์ฌ๊ธฐ์๋ NASA Great Images ๋ฐ์ดํฐ ์ธํธ์ ์ผ๋ถ์ธ ์ฐ์ฃผ๋นํ์ฌ ์์ผ๋ฆฐ ์ฝ๋ฆฐ์ค(Eileen Collins) ์ฌ์ง์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
>>> import skimage
>>> import numpy as np
>>> from PIL import Image
>>> image = skimage.data.astronaut()
>>> image = Image.fromarray(np.uint8(image)).convert("RGB")
>>> image

์ด๋ฏธ์ง์ ํด๋น ์ด๋ฏธ์ง์ ํ๋ณด ๋ ์ด๋ธ์ ํ์ดํ๋ผ์ธ์ผ๋ก ์ ๋ฌํฉ๋๋ค. ์ฌ๊ธฐ์๋ ์ด๋ฏธ์ง๋ฅผ ์ง์ ์ ๋ฌํ์ง๋ง, ์ปดํจํฐ์ ์ ์ฅ๋ ์ด๋ฏธ์ง์ ๊ฒฝ๋ก๋ url๋ก ์ ๋ฌํ ์๋ ์์ต๋๋ค. candidate_labels๋ ์ด ์์์ฒ๋ผ ๊ฐ๋จํ ๋จ์ด์ผ ์๋ ์๊ณ ์ข ๋ ์ค๋ช ์ ์ธ ๋จ์ด์ผ ์๋ ์์ต๋๋ค. ๋ํ, ์ด๋ฏธ์ง๋ฅผ ๊ฒ์(query)ํ๋ ค๋ ๋ชจ๋ ํญ๋ชฉ์ ๋ํ ํ ์คํธ ์ค๋ช ๋ ์ ๋ฌํฉ๋๋ค.
>>> predictions = detector(
... image,
... candidate_labels=["human face", "rocket", "nasa badge", "star-spangled banner"],
... )
>>> predictions
[{'score': 0.3571370542049408,
'label': 'human face',
'box': {'xmin': 180, 'ymin': 71, 'xmax': 271, 'ymax': 178}},
{'score': 0.28099656105041504,
'label': 'nasa badge',
'box': {'xmin': 129, 'ymin': 348, 'xmax': 206, 'ymax': 427}},
{'score': 0.2110239565372467,
'label': 'rocket',
'box': {'xmin': 350, 'ymin': -1, 'xmax': 468, 'ymax': 288}},
{'score': 0.13790413737297058,
'label': 'star-spangled banner',
'box': {'xmin': 1, 'ymin': 1, 'xmax': 105, 'ymax': 509}},
{'score': 0.11950037628412247,
'label': 'nasa badge',
'box': {'xmin': 277, 'ymin': 338, 'xmax': 327, 'ymax': 380}},
{'score': 0.10649408400058746,
'label': 'rocket',
'box': {'xmin': 358, 'ymin': 64, 'xmax': 424, 'ymax': 280}}]
์ด์ ์์ธก๊ฐ์ ์๊ฐํํด๋ด ์๋ค:
>>> from PIL import ImageDraw
>>> draw = ImageDraw.Draw(image)
>>> for prediction in predictions:
... box = prediction["box"]
... label = prediction["label"]
... score = prediction["score"]
... xmin, ymin, xmax, ymax = box.values()
... draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
... draw.text((xmin, ymin), f"{label}: {round(score,2)}", fill="white")
>>> image

ํ ์คํธ ํ๋กฌํํธ ๊ธฐ๋ฐ ๊ฐ์ฒด ํ์ง[[textprompted-zeroshot-object-detection-by-hand]]
์ ๋ก์ท ๊ฐ์ฒด ํ์ง ํ์ดํ๋ผ์ธ ์ฌ์ฉ๋ฒ์ ๋ํด ์ดํด๋ณด์์ผ๋, ์ด์ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ๋ณต์ ํด๋ณด๊ฒ ์ต๋๋ค.
Hugging Face Hub์ ์ ๋ก๋๋ ์ฒดํฌํฌ์ธํธ์์ ๊ด๋ จ ๋ชจ๋ธ๊ณผ ํ๋ก์ธ์๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ์ฌ๊ธฐ์๋ ์ด์ ๊ณผ ๋์ผํ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ๊ฒ ์ต๋๋ค:
>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)
>>> processor = AutoProcessor.from_pretrained(checkpoint)
๋ค๋ฅธ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค:
>>> import requests
>>> url = "https://unsplash.com/photos/oj0zeY2Ltk4/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTR8fHBpY25pY3xlbnwwfHx8fDE2Nzc0OTE1NDk&force=true&w=640"
>>> im = Image.open(requests.get(url, stream=True).raw)
>>> im

ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ์
๋ ฅ์ ์ค๋นํฉ๋๋ค.
ํ๋ก์ธ์๋ ๋ชจ๋ธ์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ธฐ ์ํด ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ ๋ณํํ๊ณ ์ ๊ทํํ๋ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ํ
์คํธ ์
๋ ฅ์ ์ฒ๋ฆฌํ๋ [CLIPTokenizer
]๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
>>> text_queries = ["hat", "book", "sunglasses", "camera"]
>>> inputs = processor(text=text_queries, images=im, return_tensors="pt")
๋ชจ๋ธ์ ์
๋ ฅ์ ์ ๋ฌํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ํ์ฒ๋ฆฌ ๋ฐ ์๊ฐํํฉ๋๋ค.
์ด๋ฏธ์ง ํ๋ก์ธ์๊ฐ ๋ชจ๋ธ์ ์ด๋ฏธ์ง๋ฅผ ์
๋ ฅํ๊ธฐ ์ ์ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ๊ธฐ ๋๋ฌธ์, [~OwlViTImageProcessor.post_process_object_detection
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด
์์ธก๊ฐ์ ๋ฐ์ด๋ฉ ๋ฐ์ค(bounding box)๊ฐ ์๋ณธ ์ด๋ฏธ์ง์ ์ขํ์ ์๋์ ์ผ๋ก ๋์ผํ์ง ํ์ธํด์ผ ํฉ๋๋ค.
>>> import torch
>>> with torch.no_grad():
... outputs = model(**inputs)
... target_sizes = torch.tensor([im.size[::-1]])
... results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]
>>> draw = ImageDraw.Draw(im)
>>> scores = results["scores"].tolist()
>>> labels = results["labels"].tolist()
>>> boxes = results["boxes"].tolist()
>>> for box, score, label in zip(boxes, scores, labels):
... xmin, ymin, xmax, ymax = box
... draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
... draw.text((xmin, ymin), f"{text_queries[label]}: {round(score,2)}", fill="white")
>>> im

์ผ๊ด ์ฒ๋ฆฌ[[batch-processing]]
์ฌ๋ฌ ์ด๋ฏธ์ง์ ํ ์คํธ ์ฟผ๋ฆฌ๋ฅผ ์ ๋ฌํ์ฌ ์ฌ๋ฌ ์ด๋ฏธ์ง์์ ์๋ก ๋ค๋ฅธ(๋๋ ๋์ผํ) ๊ฐ์ฒด๋ฅผ ๊ฒ์ํ ์ ์์ต๋๋ค. ์ผ๊ด ์ฒ๋ฆฌ๋ฅผ ์ํด์ ํ ์คํธ ์ฟผ๋ฆฌ๋ ์ด์ค ๋ฆฌ์คํธ๋ก, ์ด๋ฏธ์ง๋ PIL ์ด๋ฏธ์ง, PyTorch ํ ์, ๋๋ NumPy ๋ฐฐ์ด๋ก ์ด๋ฃจ์ด์ง ๋ฆฌ์คํธ๋ก ํ๋ก์ธ์์ ์ ๋ฌํด์ผ ํฉ๋๋ค.
>>> images = [image, im]
>>> text_queries = [
... ["human face", "rocket", "nasa badge", "star-spangled banner"],
... ["hat", "book", "sunglasses", "camera"],
... ]
>>> inputs = processor(text=text_queries, images=images, return_tensors="pt")
์ด์ ์๋ ํ์ฒ๋ฆฌ๋ฅผ ์ํด ๋จ์ผ ์ด๋ฏธ์ง์ ํฌ๊ธฐ๋ฅผ ํ
์๋ก ์ ๋ฌํ์ง๋ง, ํํ์ ์ ๋ฌํ ์ ์๊ณ , ์ฌ๋ฌ ์ด๋ฏธ์ง๋ฅผ ์ฒ๋ฆฌํ๋ ๊ฒฝ์ฐ์๋ ํํ๋ก ์ด๋ฃจ์ด์ง ๋ฆฌ์คํธ๋ฅผ ์ ๋ฌํ ์๋ ์์ต๋๋ค.
์๋ ๋ ์์ ์ ๋ํ ์์ธก์ ์์ฑํ๊ณ , ๋ ๋ฒ์งธ ์ด๋ฏธ์ง(image_idx = 1
)๋ฅผ ์๊ฐํํด ๋ณด๊ฒ ์ต๋๋ค.
>>> with torch.no_grad():
... outputs = model(**inputs)
... target_sizes = [x.size[::-1] for x in images]
... results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)
>>> image_idx = 1
>>> draw = ImageDraw.Draw(images[image_idx])
>>> scores = results[image_idx]["scores"].tolist()
>>> labels = results[image_idx]["labels"].tolist()
>>> boxes = results[image_idx]["boxes"].tolist()
>>> for box, score, label in zip(boxes, scores, labels):
... xmin, ymin, xmax, ymax = box
... draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
... draw.text((xmin, ymin), f"{text_queries[image_idx][label]}: {round(score,2)}", fill="white")
>>> images[image_idx]

์ด๋ฏธ์ง ๊ฐ์ด๋ ๊ฐ์ฒด ํ์ง[[imageguided-object-detection]]
ํ ์คํธ ์ฟผ๋ฆฌ๋ฅผ ์ด์ฉํ ์ ๋ก์ท ๊ฐ์ฒด ํ์ง ์ธ์๋ OWL-ViT ๋ชจ๋ธ์ ์ด๋ฏธ์ง ๊ฐ์ด๋ ๊ฐ์ฒด ํ์ง ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค. ์ด๋ฏธ์ง๋ฅผ ์ฟผ๋ฆฌ๋ก ์ฌ์ฉํด ๋์ ์ด๋ฏธ์ง์์ ์ ์ฌํ ๊ฐ์ฒด๋ฅผ ์ฐพ์ ์ ์๋ค๋ ์๋ฏธ์ ๋๋ค. ํ ์คํธ ์ฟผ๋ฆฌ์ ๋ฌ๋ฆฌ ํ๋์ ์์ ์ด๋ฏธ์ง์์๋ง ๊ฐ๋ฅํฉ๋๋ค.
์ํ์ ๊ณ ์์ด ๋ ๋ง๋ฆฌ๊ฐ ์๋ ์ด๋ฏธ์ง๋ฅผ ๋์ ์ด๋ฏธ์ง(target image)๋ก, ๊ณ ์์ด ํ ๋ง๋ฆฌ๊ฐ ์๋ ์ด๋ฏธ์ง๋ฅผ ์ฟผ๋ฆฌ๋ก ์ฌ์ฉํด๋ณด๊ฒ ์ต๋๋ค:
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image_target = Image.open(requests.get(url, stream=True).raw)
>>> query_url = "http://images.cocodataset.org/val2017/000000524280.jpg"
>>> query_image = Image.open(requests.get(query_url, stream=True).raw)
๋ค์ ์ด๋ฏธ์ง๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค:
>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots(1, 2)
>>> ax[0].imshow(image_target)
>>> ax[1].imshow(query_image)

์ ์ฒ๋ฆฌ ๋จ๊ณ์์ ํ
์คํธ ์ฟผ๋ฆฌ ๋์ ์ query_images
๋ฅผ ์ฌ์ฉํฉ๋๋ค:
>>> inputs = processor(images=image_target, query_images=query_image, return_tensors="pt")
์์ธก์ ๊ฒฝ์ฐ, ๋ชจ๋ธ์ ์
๋ ฅ์ ์ ๋ฌํ๋ ๋์ [~OwlViTForObjectDetection.image_guided_detection
]์ ์ ๋ฌํฉ๋๋ค.
๋ ์ด๋ธ์ด ์๋ค๋ ์ ์ ์ ์ธํ๋ฉด ์ด์ ๊ณผ ๋์ผํฉ๋๋ค.
์ด์ ๊ณผ ๋์ผํ๊ฒ ์ด๋ฏธ์ง๋ฅผ ์๊ฐํํฉ๋๋ค.
>>> with torch.no_grad():
... outputs = model.image_guided_detection(**inputs)
... target_sizes = torch.tensor([image_target.size[::-1]])
... results = processor.post_process_image_guided_detection(outputs=outputs, target_sizes=target_sizes)[0]
>>> draw = ImageDraw.Draw(image_target)
>>> scores = results["scores"].tolist()
>>> boxes = results["boxes"].tolist()
>>> for box, score, label in zip(boxes, scores, labels):
... xmin, ymin, xmax, ymax = box
... draw.rectangle((xmin, ymin, xmax, ymax), outline="white", width=4)
>>> image_target

OWL-ViT ๋ชจ๋ธ์ ์ถ๋ก ํ๊ณ ์ถ๋ค๋ฉด ์๋ ๋ฐ๋ชจ๋ฅผ ํ์ธํ์ธ์: