
Zero-shot image classification[[zeroshot-image-classification]]

[[open-in-colab]]

Zero-shot image classification is the task of classifying images into categories using a model that was not explicitly trained on labeled examples from those specific categories.

Traditionally, image classification requires training a model on a specific set of labeled images, and the model learns to "map" certain image features to those labels. When such a model has to be used for a classification task with a new set of labels, it must be fine-tuned to "recalibrate" it.

In contrast, zero-shot or open vocabulary image classification models are typically multimodal models trained on a large dataset of images and their associated descriptions. These models learn aligned vision-language representations that can be used for many downstream tasks, including zero-shot image classification.

This is a more flexible approach to image classification: it lets a model generalize to new labels and unseen categories without additional training data. It also lets users query images with free-form text descriptions of the objects they are looking for.
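To make the idea of aligned representations concrete, here is a minimal sketch of zero-shot prediction as a nearest-neighbor search in a shared embedding space. It assumes you already have one image embedding and one text embedding per candidate label; the tiny three-dimensional vectors below are made up for illustration, whereas a real model such as CLIP produces much larger ones with its image and text encoders.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def zero_shot_predict(image_embedding, label_embeddings):
    """Return the label whose text embedding is closest to the image embedding.

    label_embeddings maps each free-form label to its (assumed, precomputed)
    text embedding; in a real model these come from the text encoder.
    """
    similarities = {
        label: cosine_similarity(image_embedding, emb)
        for label, emb in label_embeddings.items()
    }
    best_label = max(similarities, key=similarities.get)
    return best_label, similarities


# Hypothetical toy embeddings, for illustration only.
image_embedding = [0.9, 0.1, 0.3]
label_embeddings = {
    "owl": [0.8, 0.2, 0.4],
    "fox": [0.1, 0.9, 0.2],
    "seagull": [0.2, 0.3, 0.9],
}

best, sims = zero_shot_predict(image_embedding, label_embeddings)
print(best)  # owl
```

Supporting a new category only means adding another entry to `label_embeddings`; no weights are updated, which is what makes the approach "zero-shot".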

In this guide you'll learn how to:

  • create a zero-shot image classification pipeline
  • run zero-shot image classification inference by hand

Before you begin, make sure you have all the necessary libraries installed. Besides transformers, the examples below use PyTorch, Pillow, and requests:

pip install -q transformers torch pillow requests

Zero-shot image classification pipeline[[zeroshot-image-classification-pipeline]]

The simplest way to try out inference with a model that supports zero-shot image classification is to use the corresponding [pipeline]. Instantiate a pipeline from a checkpoint on the Hugging Face Hub:

>>> from transformers import pipeline

>>> checkpoint = "openai/clip-vit-large-patch14"
>>> classifier = pipeline(model=checkpoint, task="zero-shot-image-classification")

๋‹ค์Œ์œผ๋กœ, ๋ถ„๋ฅ˜ํ•˜๊ณ  ์‹ถ์€ ์ด๋ฏธ์ง€๋ฅผ ์„ ํƒํ•˜์„ธ์š”.

>>> from PIL import Image
>>> import requests

>>> url = "https://unsplash.com/photos/g8oS8-82DxI/download?ixid=MnwxMjA3fDB8MXx0b3BpY3x8SnBnNktpZGwtSGt8fHx8fDJ8fDE2NzgxMDYwODc&force=true&w=640"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image
Photo of an owl

Pass the image and the candidate labels (candidate_labels) to the pipeline. Here we pass the image object directly, but you can also pass the path to an image stored on your computer or an image URL. The candidate labels can be simple words, as in this example, or more descriptive phrases.
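Descriptive phrases can help because the text encoder reads each label as text. As far as we understand the pipeline, it inserts each candidate label into a prompt template (its `hypothesis_template` argument, which we believe defaults to `"This is a photo of {}."`) before tokenizing. The sketch below reproduces only that string-formatting step; `build_prompts` and `DEFAULT_TEMPLATE` are hypothetical names, not part of the transformers API.

```python
# Assumed default template; verify against the pipeline docs for your version.
DEFAULT_TEMPLATE = "This is a photo of {}."


def build_prompts(candidate_labels, template=DEFAULT_TEMPLATE):
    """Insert each label into the template, producing one prompt per label."""
    return [template.format(label) for label in candidate_labels]


prompts = build_prompts(["fox", "bear", "seagull", "owl"])
print(prompts[0])  # This is a photo of fox.
```

If another phrasing suits your labels better, the pipeline call can take your own template, for example `classifier(image, candidate_labels=["fox", "owl"], hypothesis_template="A photo of a {}.")`; check the pipeline reference for your installed version.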

>>> predictions = classifier(image, candidate_labels=["fox", "bear", "seagull", "owl"])
>>> predictions
[{'score': 0.9996670484542847, 'label': 'owl'},
 {'score': 0.000199399160919711, 'label': 'seagull'},
 {'score': 7.392891711788252e-05, 'label': 'fox'},
 {'score': 5.96074532950297e-05, 'label': 'bear'}]
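The pipeline returns a plain list of dictionaries, so common follow-up steps need no extra API. A small sketch, reusing the scores printed above, that picks the top label and applies a confidence threshold; the 0.5 cut-off is an arbitrary assumption you would tune for your own data.

```python
# Pipeline output copied from the example above.
predictions = [
    {"score": 0.9996670484542847, "label": "owl"},
    {"score": 0.000199399160919711, "label": "seagull"},
    {"score": 7.392891711788252e-05, "label": "fox"},
    {"score": 5.96074532950297e-05, "label": "bear"},
]

# Top-1 prediction: take the highest score instead of relying on list order.
top = max(predictions, key=lambda p: p["score"])
print(top["label"])  # owl

# Keep only labels whose score clears a (hypothetical) confidence threshold.
THRESHOLD = 0.5
confident = [p["label"] for p in predictions if p["score"] >= THRESHOLD]
print(confident)  # ['owl']

# The scores are a distribution over the candidate labels, so they sum to ~1.
print(round(sum(p["score"] for p in predictions), 4))  # 1.0
```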

์ง์ ‘ ์ œ๋กœ์ƒท(zero-shot) ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ํ•˜๊ธฐ[[zeroshot-image-classification-by-hand]]

Now that you've seen how to use the zero-shot image classification pipeline, let's look at how to run zero-shot image classification manually.

Start by loading the model and its associated processor from a checkpoint on the Hugging Face Hub. Here we'll use the same checkpoint as before:

>>> from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

>>> model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
>>> processor = AutoProcessor.from_pretrained(checkpoint)

Let's use a different image this time.

>>> from PIL import Image
>>> import requests

>>> url = "https://unsplash.com/photos/xBRQfR2bqNI/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjc4Mzg4ODEx&force=true&w=640"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image
Photo of a car

Use the processor to prepare the inputs for the model. The processor combines an image processor, which resizes and normalizes the image so it can be used as model input, and a tokenizer, which handles the text inputs.

>>> candidate_labels = ["tree", "car", "bike", "cat"]
>>> inputs = processor(images=image, text=candidate_labels, return_tensors="pt", padding=True)
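To see what "normalizing" means here, the sketch below applies per-channel normalization to a single RGB pixel: rescale 0 to 255 values to [0, 1], subtract a channel mean, and divide by a channel standard deviation. The mean and std constants are the ones commonly published with OpenAI CLIP and are an assumption of this sketch; the processor for your checkpoint exposes its actual values as `processor.image_processor.image_mean` and `processor.image_processor.image_std`. Resizing and cropping are left out.

```python
# Assumed OpenAI CLIP normalization constants (one value per R, G, B channel).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def normalize_pixel(rgb, mean=CLIP_MEAN, std=CLIP_STD, rescale_factor=1 / 255):
    """Normalize one (R, G, B) pixel whose channel values are in 0..255."""
    return tuple(
        (value * rescale_factor - m) / s for value, m, s in zip(rgb, mean, std)
    )


# A pure white pixel ends up well above zero on every channel.
print(normalize_pixel((255, 255, 255)))
```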

๋ชจ๋ธ์— ์ž…๋ ฅ์„ ์ „๋‹ฌํ•˜๊ณ , ๊ฒฐ๊ณผ๋ฅผ ํ›„์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค:

>>> import torch

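>>> # Inference only, so gradient tracking is disabled
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # logits_per_image has shape (num_images, num_labels); take the single image
>>> logits = outputs.logits_per_image[0]
>>> # Softmax over the candidate labels turns the logits into probabilities
>>> probs = logits.softmax(dim=-1).numpy()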

>>> result = [
...     {"score": score, "label": candidate_label}
...     for score, candidate_label in sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
... ]
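The post-processing above is a softmax over the per-label logits followed by a sort. Here is the same arithmetic in plain Python, without torch or NumPy; the logit values are hypothetical stand-ins for `outputs.logits_per_image[0]`. For CLIP, these logits are the image-text cosine similarities multiplied by the model's learned scale (`logit_scale`).

```python
import math


def softmax(values):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(logits, candidate_labels):
    """Return {"score", "label"} dicts sorted from most to least likely."""
    probs = softmax(logits)
    ranked = sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
    return [{"score": p, "label": label} for p, label in ranked]


# Hypothetical logits for ["tree", "car", "bike", "cat"].
logits = [18.0, 25.0, 19.5, 14.0]
result = rank_labels(logits, ["tree", "car", "bike", "cat"])
print([r["label"] for r in result])  # ['car', 'bike', 'tree', 'cat']
```

Because the softmax is taken over the candidate labels only, the scores are relative to the set you pass in: adding or removing a label changes every score.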

>>> result
[{'score': 0.998572, 'label': 'car'},
 {'score': 0.0010570387, 'label': 'bike'},
 {'score': 0.0003393686, 'label': 'tree'},
 {'score': 3.1572064e-05, 'label': 'cat'}]