# Zero-shot object detection[[zeroshot-object-detection]]

[[open-in-colab]]

์ผ๋ฐ˜์ ์œผ๋กœ ๊ฐ์ฒด ํƒ์ง€์— ์‚ฌ์šฉ๋˜๋Š” ๋ชจ๋ธ์„ ํ•™์Šตํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋ ˆ์ด๋ธ”์ด ์ง€์ •๋œ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ ์„ธํŠธ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ํ•™์Šต ๋ฐ์ดํ„ฐ์— ์กด์žฌํ•˜๋Š” ํด๋ž˜์Šค(๋ ˆ์ด๋ธ”)๋งŒ ํƒ์ง€ํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ํ•œ๊ณ„์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

Zero-shot object detection is supported by the OWL-ViT model, which uses a different approach. OWL-ViT is an open-vocabulary object detector: it can detect objects in images based on free-text queries, without the need to fine-tune the model on labeled datasets.

OWL-ViT ๋ชจ๋ธ์€ ๋ฉ€ํ‹ฐ ๋ชจ๋‹ฌ ํ‘œํ˜„์„ ํ™œ์šฉํ•ด ๊ฐœ๋ฐฉํ˜• ์–ดํœ˜ ํƒ์ง€(open-vocabulary detection)๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. CLIP ๋ชจ๋ธ์— ๊ฒฝ๋Ÿ‰ํ™”(lightweight)๋œ ๊ฐ์ฒด ๋ถ„๋ฅ˜์™€ ์ง€์—ญํ™”(localization) ํ—ค๋“œ๋ฅผ ๊ฒฐํ•ฉํ•ฉ๋‹ˆ๋‹ค. ๊ฐœ๋ฐฉํ˜• ์–ดํœ˜ ํƒ์ง€๋Š” CLIP์˜ ํ…์ŠคํŠธ ์ธ์ฝ”๋”๋กœ free-text ์ฟผ๋ฆฌ๋ฅผ ์ž„๋ฒ ๋”ฉํ•˜๊ณ , ๊ฐ์ฒด ๋ถ„๋ฅ˜์™€ ์ง€์—ญํ™” ํ—ค๋“œ์˜ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€์™€ ํ•ด๋‹น ํ…์ŠคํŠธ ์„ค๋ช…์„ ์—ฐ๊ฒฐํ•˜๋ฉด ViT๊ฐ€ ์ด๋ฏธ์ง€ ํŒจ์น˜(image patches)๋ฅผ ์ž…๋ ฅ์œผ๋กœ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค. OWL-ViT ๋ชจ๋ธ์˜ ์ €์ž๋“ค์€ CLIP ๋ชจ๋ธ์„ ์ฒ˜์Œ๋ถ€ํ„ฐ ํ•™์Šต(scratch learning)ํ•œ ํ›„์—, bipartite matching loss๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ‘œ์ค€ ๊ฐ์ฒด ์ธ์‹ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ OWL-ViT ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ–ˆ์Šต๋‹ˆ๋‹ค.

With this approach, the model can detect objects based on textual descriptions without prior training on labeled datasets.
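
To make this concrete, here is a minimal, illustrative sketch of the open-vocabulary idea (not the actual OWL-ViT implementation; all tensors below are random stand-ins): instead of a classifier with one learned weight vector per fixed class, the class "weights" come from text embeddings, so any free-text query can act as a class at inference time.

```py
import torch

# Hypothetical sizes, for illustration only.
num_patches, embed_dim = 576, 512

# Per-patch image embeddings from a ViT backbone (stand-in: random).
patch_embeds = torch.randn(num_patches, embed_dim)

# Text embeddings of free-text queries from a CLIP-style text encoder
# (stand-in: random). Adding a query means adding a row; no retraining needed.
query_embeds = torch.randn(3, embed_dim)  # e.g. "hat", "book", "camera"

# Open-vocabulary classification: patch-vs-query similarity replaces
# a fixed, learned classification layer.
logits = patch_embeds @ query_embeds.T  # (num_patches, num_queries)

# A separate lightweight localization head would predict one box per patch;
# here we only pick each patch's best-matching query.
best_query = logits.argmax(dim=-1)
```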

์ด๋ฒˆ ๊ฐ€์ด๋“œ์—์„œ๋Š” OWL-ViT ๋ชจ๋ธ์˜ ์‚ฌ์šฉ๋ฒ•์„ ๋‹ค๋ฃฐ ๊ฒƒ์ž…๋‹ˆ๋‹ค:

- to detect objects based on text prompts
- for batch object detection
- for image-guided object detection

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install -q transformers
```

## Zero-shot object detection pipeline[[zeroshot-object-detection-pipeline]]

The simplest way to try out inference with OWL-ViT is to use it in a [`pipeline`]. Instantiate a pipeline for zero-shot object detection from a checkpoint on the Hugging Face Hub:

```py
>>> from transformers import pipeline

>>> checkpoint = "google/owlvit-base-patch32"
>>> detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
```

๋‹ค์Œ์œผ๋กœ, ๊ฐ์ฒด๋ฅผ ํƒ์ง€ํ•˜๊ณ  ์‹ถ์€ ์ด๋ฏธ์ง€๋ฅผ ์„ ํƒํ•˜์„ธ์š”. ์—ฌ๊ธฐ์„œ๋Š” NASA Great Images ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ์ผ๋ถ€์ธ ์šฐ์ฃผ๋น„ํ–‰์‚ฌ ์—์ผ๋ฆฐ ์ฝœ๋ฆฐ์Šค(Eileen Collins) ์‚ฌ์ง„์„ ์‚ฌ์šฉํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

```py
>>> import skimage
>>> import numpy as np
>>> from PIL import Image

>>> image = skimage.data.astronaut()
>>> image = Image.fromarray(np.uint8(image)).convert("RGB")

>>> image
```

*Astronaut Eileen Collins*

Pass the image and the candidate object labels to the pipeline. Here we pass the image directly; other suitable options include a local path to an image or an image URL. The `candidate_labels` can be simple words as in this example, or more descriptive phrases, and we pass a text description for every item we want to query the image for.

```py
>>> predictions = detector(
...     image,
...     candidate_labels=["human face", "rocket", "nasa badge", "star-spangled banner"],
... )
>>> predictions
[{'score': 0.3571370542049408,
  'label': 'human face',
  'box': {'xmin': 180, 'ymin': 71, 'xmax': 271, 'ymax': 178}},
 {'score': 0.28099656105041504,
  'label': 'nasa badge',
  'box': {'xmin': 129, 'ymin': 348, 'xmax': 206, 'ymax': 427}},
 {'score': 0.2110239565372467,
  'label': 'rocket',
  'box': {'xmin': 350, 'ymin': -1, 'xmax': 468, 'ymax': 288}},
 {'score': 0.13790413737297058,
  'label': 'star-spangled banner',
  'box': {'xmin': 1, 'ymin': 1, 'xmax': 105, 'ymax': 509}},
 {'score': 0.11950037628412247,
  'label': 'nasa badge',
  'box': {'xmin': 277, 'ymin': 338, 'xmax': 327, 'ymax': 380}},
 {'score': 0.10649408400058746,
  'label': 'rocket',
  'box': {'xmin': 358, 'ymin': 64, 'xmax': 424, 'ymax': 280}}]
```
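
As mentioned above, instead of a PIL image you can also hand the pipeline a local path or a URL, and the candidate labels can be full phrases. A minimal sketch, with a placeholder URL rather than a real image:

```py
>>> # The URL below is a placeholder for illustration; substitute a real image URL.
>>> predictions = detector(
...     "https://example.com/astronaut.jpg",
...     candidate_labels=["the face of an astronaut", "a model of a rocket"],
... )
```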

Let's visualize the predictions:

```py
>>> from PIL import ImageDraw

>>> draw = ImageDraw.Draw(image)

>>> for prediction in predictions:
...     box = prediction["box"]
...     label = prediction["label"]
...     score = prediction["score"]
...     xmin, ymin, xmax, ymax = box.values()
...     draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
...     draw.text((xmin, ymin), f"{label}: {round(score,2)}", fill="white")

>>> image
```

*Visualized predictions on NASA image*

## Text-prompted zero-shot object detection by hand[[textprompted-zeroshot-object-detection-by-hand]]

Now that you've seen how to use the zero-shot object detection pipeline, let's replicate the same result by hand.

Start by loading the model and the associated processor from a checkpoint on the Hugging Face Hub. Here we'll use the same checkpoint as before:

```py
>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)
>>> processor = AutoProcessor.from_pretrained(checkpoint)
```

Let's use a different image this time:

```py
>>> import requests

>>> url = "https://unsplash.com/photos/oj0zeY2Ltk4/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTR8fHBpY25pY3xlbnwwfHx8fDE2Nzc0OTE1NDk&force=true&w=640"
>>> im = Image.open(requests.get(url, stream=True).raw)
>>> im
```

*Beach photo*

Use the processor to prepare the inputs for the model. The processor combines an image processor, which resizes and normalizes the image for the model, with a [`CLIPTokenizer`] that takes care of the text inputs.

```py
>>> text_queries = ["hat", "book", "sunglasses", "camera"]
>>> inputs = processor(text=text_queries, images=im, return_tensors="pt")
```
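
Before running the model, it can be useful to peek at what the processor produced: the tokenized text queries (`input_ids`, `attention_mask`) alongside the resized and normalized image (`pixel_values`). A quick, optional check; the exact shapes depend on the checkpoint:

```py
>>> for name, tensor in inputs.items():
...     print(name, tuple(tensor.shape))
```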

๋ชจ๋ธ์— ์ž…๋ ฅ์„ ์ „๋‹ฌํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ํ›„์ฒ˜๋ฆฌ ๋ฐ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๊ฐ€ ๋ชจ๋ธ์— ์ด๋ฏธ์ง€๋ฅผ ์ž…๋ ฅํ•˜๊ธฐ ์ „์— ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋ฅผ ์กฐ์ •ํ–ˆ๊ธฐ ๋•Œ๋ฌธ์—, [~OwlViTImageProcessor.post_process_object_detection] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•ด ์˜ˆ์ธก๊ฐ’์˜ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค(bounding box)๊ฐ€ ์›๋ณธ ์ด๋ฏธ์ง€์˜ ์ขŒํ‘œ์™€ ์ƒ๋Œ€์ ์œผ๋กœ ๋™์ผํ•œ์ง€ ํ™•์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

```py
>>> import torch

>>> with torch.no_grad():
...     outputs = model(**inputs)
...     target_sizes = torch.tensor([im.size[::-1]])
...     results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]

>>> draw = ImageDraw.Draw(im)

>>> scores = results["scores"].tolist()
>>> labels = results["labels"].tolist()
>>> boxes = results["boxes"].tolist()

>>> for box, score, label in zip(boxes, scores, labels):
...     xmin, ymin, xmax, ymax = box
...     draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
...     draw.text((xmin, ymin), f"{text_queries[label]}: {round(score,2)}", fill="white")

>>> im
```

*Beach photo with detected objects*

## Batch processing[[batch-processing]]

You can pass multiple sets of images and text queries to search for different (or the same) objects in several images. For batch processing, pass the text queries as a nested list and the images as a list of PIL images, PyTorch tensors, or NumPy arrays to the processor.

```py
>>> images = [image, im]
>>> text_queries = [
...     ["human face", "rocket", "nasa badge", "star-spangled banner"],
...     ["hat", "book", "sunglasses", "camera"],
... ]
>>> inputs = processor(text=text_queries, images=images, return_tensors="pt")
```

์ด์ „์—๋Š” ํ›„์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•ด ๋‹จ์ผ ์ด๋ฏธ์ง€์˜ ํฌ๊ธฐ๋ฅผ ํ…์„œ๋กœ ์ „๋‹ฌํ–ˆ์ง€๋งŒ, ํŠœํ”Œ์„ ์ „๋‹ฌํ•  ์ˆ˜ ์žˆ๊ณ , ์—ฌ๋Ÿฌ ์ด๋ฏธ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ฒฝ์šฐ์—๋Š” ํŠœํ”Œ๋กœ ์ด๋ฃจ์–ด์ง„ ๋ฆฌ์ŠคํŠธ๋ฅผ ์ „๋‹ฌํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. ์•„๋ž˜ ๋‘ ์˜ˆ์ œ์— ๋Œ€ํ•œ ์˜ˆ์ธก์„ ์ƒ์„ฑํ•˜๊ณ , ๋‘ ๋ฒˆ์งธ ์ด๋ฏธ์ง€(image_idx = 1)๋ฅผ ์‹œ๊ฐํ™”ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

```py
>>> with torch.no_grad():
...     outputs = model(**inputs)
...     target_sizes = [x.size[::-1] for x in images]
...     results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)
```
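
[`~OwlViTImageProcessor.post_process_object_detection`] returns one result dictionary per input image, in the same order as `images`; each dictionary holds that image's `scores`, `labels`, and `boxes`, which is what the indexing below relies on:

```py
>>> len(results)  # one result dictionary per input image
2
```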

```py
>>> image_idx = 1
>>> draw = ImageDraw.Draw(images[image_idx])

>>> scores = results[image_idx]["scores"].tolist()
>>> labels = results[image_idx]["labels"].tolist()
>>> boxes = results[image_idx]["boxes"].tolist()

>>> for box, score, label in zip(boxes, scores, labels):
...     xmin, ymin, xmax, ymax = box
...     draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
...     draw.text((xmin, ymin), f"{text_queries[image_idx][label]}: {round(score,2)}", fill="white")

>>> images[image_idx]
```

*Beach photo with detected objects*

## Image-guided object detection[[imageguided-object-detection]]

In addition to zero-shot object detection with text queries, OWL-ViT offers image-guided object detection. This means you can use an image query to find similar objects in a target image. Unlike text queries, only a single example image is allowed.

Let's take an image with two cats on a couch as the target image, and an image of a single cat as the query:

```py
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image_target = Image.open(requests.get(url, stream=True).raw)

>>> query_url = "http://images.cocodataset.org/val2017/000000524280.jpg"
>>> query_image = Image.open(requests.get(query_url, stream=True).raw)
```

๋‹ค์Œ ์ด๋ฏธ์ง€๋ฅผ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

```py
>>> import matplotlib.pyplot as plt

>>> fig, ax = plt.subplots(1, 2)
>>> ax[0].imshow(image_target)
>>> ax[1].imshow(query_image)
```

*Cats*

In the preprocessing step, use `query_images` instead of text queries:

```py
>>> inputs = processor(images=image_target, query_images=query_image, return_tensors="pt")
```

์˜ˆ์ธก์˜ ๊ฒฝ์šฐ, ๋ชจ๋ธ์— ์ž…๋ ฅ์„ ์ „๋‹ฌํ•˜๋Š” ๋Œ€์‹  [~OwlViTForObjectDetection.image_guided_detection]์— ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ๋ ˆ์ด๋ธ”์ด ์—†๋‹ค๋Š” ์ ์„ ์ œ์™ธํ•˜๋ฉด ์ด์ „๊ณผ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. ์ด์ „๊ณผ ๋™์ผํ•˜๊ฒŒ ์ด๋ฏธ์ง€๋ฅผ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.

```py
>>> with torch.no_grad():
...     outputs = model.image_guided_detection(**inputs)
...     target_sizes = torch.tensor([image_target.size[::-1]])
...     results = processor.post_process_image_guided_detection(outputs=outputs, target_sizes=target_sizes)[0]

>>> draw = ImageDraw.Draw(image_target)

>>> scores = results["scores"].tolist()
>>> boxes = results["boxes"].tolist()

>>> for box, score in zip(boxes, scores):
...     xmin, ymin, xmax, ymax = box
...     draw.rectangle((xmin, ymin, xmax, ymax), outline="white", width=4)

>>> image_target
```

*Cats with bounding boxes*

If you'd like to interactively try out inference with OWL-ViT, check out this demo: