์ ๋ก์ท(zero-shot) ์ด๋ฏธ์ง ๋ถ๋ฅ[[zeroshot-image-classification]]
[[open-in-colab]]
์ ๋ก์ท(zero-shot) ์ด๋ฏธ์ง ๋ถ๋ฅ๋ ํน์ ์นดํ ๊ณ ๋ฆฌ์ ์์๊ฐ ํฌํจ๋ ๋ฐ์ดํฐ๋ฅผ ํ์ต๋์ง ์์ ๋ชจ๋ธ์ ์ฌ์ฉํด ์ด๋ฏธ์ง ๋ถ๋ฅ๋ฅผ ์ํํ๋ ์์ ์ ๋๋ค.
์ผ๋ฐ์ ์ผ๋ก ์ด๋ฏธ์ง ๋ถ๋ฅ๋ฅผ ์ํด์๋ ๋ ์ด๋ธ์ด ๋ฌ๋ฆฐ ํน์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ก ๋ชจ๋ธ ํ์ต์ด ํ์ํ๋ฉฐ, ์ด ๋ชจ๋ธ์ ํน์ ์ด๋ฏธ์ง์ ํน์ง์ ๋ ์ด๋ธ์ "๋งคํ"ํ๋ ๋ฐฉ๋ฒ์ ํ์ตํฉ๋๋ค. ์๋ก์ด ๋ ์ด๋ธ์ด ์๋ ๋ถ๋ฅ ์์ ์ ์ด๋ฌํ ๋ชจ๋ธ์ ์ฌ์ฉํด์ผ ํ๋ ๊ฒฝ์ฐ์๋, ๋ชจ๋ธ์ "์ฌ๋ณด์ "ํ๊ธฐ ์ํด ๋ฏธ์ธ ์กฐ์ ์ด ํ์ํฉ๋๋ค.
์ด์ ๋์กฐ์ ์ผ๋ก, ์ ๋ก์ท ๋๋ ๊ฐ๋ฐฉํ ์ดํ(open vocabulary) ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก ๋๊ท๋ชจ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ํด๋น ์ค๋ช ์ ๋ํด ํ์ต๋ ๋ฉํฐ๋ชจ๋ฌ(multimodal) ๋ชจ๋ธ์ ๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ์ ์ ๋ก์ท ์ด๋ฏธ์ง ๋ถ๋ฅ๋ฅผ ํฌํจํ ๋ง์ ๋ค์ด์คํธ๋ฆผ ์์ ์ ์ฌ์ฉํ ์ ์๋ ์ ๋ ฌ๋(aligned) ๋น์ ์ธ์ด ํํ์ ํ์ตํฉ๋๋ค.
์ด๋ ์ด๋ฏธ์ง ๋ถ๋ฅ์ ๋ํ ๋ณด๋ค ์ ์ฐํ ์ ๊ทผ ๋ฐฉ์์ผ๋ก, ์ถ๊ฐ ํ์ต ๋ฐ์ดํฐ ์์ด ์๋ก์ด ๋ ์ด๋ธ์ด๋ ํ์ตํ์ง ๋ชปํ ์นดํ ๊ณ ๋ฆฌ์ ๋ํด ๋ชจ๋ธ์ ์ผ๋ฐํํ ์ ์์ต๋๋ค. ๋ํ, ์ฌ์ฉ์๊ฐ ๋์ ๊ฐ์ฒด์ ๋ํ ์์ ํ์์ ํ ์คํธ ์ค๋ช ์ผ๋ก ์ด๋ฏธ์ง๋ฅผ ๊ฒ์ํ ์ ์์ต๋๋ค.
์ด๋ฒ ๊ฐ์ด๋์์ ๋ฐฐ์ธ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ ๋ก์ท ์ด๋ฏธ์ง ๋ถ๋ฅ ํ์ดํ๋ผ์ธ ๋ง๋ค๊ธฐ
- ์ง์ ์ ๋ก์ท ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ชจ๋ธ ์ถ๋ก ์คํํ๊ธฐ
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๋ชจ๋ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
pip install -q transformers
์ ๋ก์ท(zero-shot) ์ด๋ฏธ์ง ๋ถ๋ฅ ํ์ดํ๋ผ์ธ[[zeroshot-image-classification-pipeline]]
[pipeline
]์ ํ์ฉํ๋ฉด ๊ฐ์ฅ ๊ฐ๋จํ๊ฒ ์ ๋ก์ท ์ด๋ฏธ์ง ๋ถ๋ฅ๋ฅผ ์ง์ํ๋ ๋ชจ๋ธ๋ก ์ถ๋ก ํด๋ณผ ์ ์์ต๋๋ค.
Hugging Face Hub์ ์
๋ก๋๋ ์ฒดํฌํฌ์ธํธ์์ ํ์ดํ๋ผ์ธ์ ์ธ์คํด์คํํฉ๋๋ค.
>>> from transformers import pipeline
>>> checkpoint = "openai/clip-vit-large-patch14"
>>> detector = pipeline(model=checkpoint, task="zero-shot-image-classification")
๋ค์์ผ๋ก, ๋ถ๋ฅํ๊ณ ์ถ์ ์ด๋ฏธ์ง๋ฅผ ์ ํํ์ธ์.
>>> from PIL import Image
>>> import requests
>>> url = "https://unsplash.com/photos/g8oS8-82DxI/download?ixid=MnwxMjA3fDB8MXx0b3BpY3x8SnBnNktpZGwtSGt8fHx8fDJ8fDE2NzgxMDYwODc&force=true&w=640"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image

์ด๋ฏธ์ง์ ํด๋น ์ด๋ฏธ์ง์ ํ๋ณด ๋ ์ด๋ธ์ธ candidate_labels
๋ฅผ ํ์ดํ๋ผ์ธ์ผ๋ก ์ ๋ฌํฉ๋๋ค.
์ฌ๊ธฐ์๋ ์ด๋ฏธ์ง๋ฅผ ์ง์ ์ ๋ฌํ์ง๋ง, ์ปดํจํฐ์ ์ ์ฅ๋ ์ด๋ฏธ์ง์ ๊ฒฝ๋ก๋ url๋ก ์ ๋ฌํ ์๋ ์์ต๋๋ค.
candidate_labels
๋ ์ด ์์์ฒ๋ผ ๊ฐ๋จํ ๋จ์ด์ผ ์๋ ์๊ณ ์ข ๋ ์ค๋ช
์ ์ธ ๋จ์ด์ผ ์๋ ์์ต๋๋ค.
>>> predictions = classifier(image, candidate_labels=["fox", "bear", "seagull", "owl"])
>>> predictions
[{'score': 0.9996670484542847, 'label': 'owl'},
{'score': 0.000199399160919711, 'label': 'seagull'},
{'score': 7.392891711788252e-05, 'label': 'fox'},
{'score': 5.96074532950297e-05, 'label': 'bear'}]
์ง์ ์ ๋ก์ท(zero-shot) ์ด๋ฏธ์ง ๋ถ๋ฅํ๊ธฐ[[zeroshot-image-classification-by-hand]]
์ด์ ์ ๋ก์ท ์ด๋ฏธ์ง ๋ถ๋ฅ ํ์ดํ๋ผ์ธ ์ฌ์ฉ ๋ฐฉ๋ฒ์ ์ดํด๋ณด์์ผ๋, ์คํํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
Hugging Face Hub์ ์ ๋ก๋๋ ์ฒดํฌํฌ์ธํธ์์ ๋ชจ๋ธ๊ณผ ํ๋ก์ธ์๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ์ฌ๊ธฐ์๋ ์ด์ ๊ณผ ๋์ผํ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ๊ฒ ์ต๋๋ค:
>>> from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
>>> model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint)
>>> processor = AutoProcessor.from_pretrained(checkpoint)
๋ค๋ฅธ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค.
>>> from PIL import Image
>>> import requests
>>> url = "https://unsplash.com/photos/xBRQfR2bqNI/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjc4Mzg4ODEx&force=true&w=640"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image

ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ์ ๋ ฅ์ ์ค๋นํฉ๋๋ค. ํ๋ก์ธ์๋ ๋ชจ๋ธ์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ธฐ ์ํด ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ ๋ณํํ๊ณ ์ ๊ทํํ๋ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ํ ์คํธ ์ ๋ ฅ์ ์ฒ๋ฆฌํ๋ ํ ํฌ๋์ด์ ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
>>> candidate_labels = ["tree", "car", "bike", "cat"]
>>> inputs = processor(images=image, text=candidate_labels, return_tensors="pt", padding=True)
๋ชจ๋ธ์ ์ ๋ ฅ์ ์ ๋ฌํ๊ณ , ๊ฒฐ๊ณผ๋ฅผ ํ์ฒ๋ฆฌํฉ๋๋ค:
>>> import torch
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> logits = outputs.logits_per_image[0]
>>> probs = logits.softmax(dim=-1).numpy()
>>> scores = probs.tolist()
>>> result = [
... {"score": score, "label": candidate_label}
... for score, candidate_label in sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
... ]
>>> result
[{'score': 0.998572, 'label': 'car'},
{'score': 0.0010570387, 'label': 'bike'},
{'score': 0.0003393686, 'label': 'tree'},
{'score': 3.1572064e-05, 'label': 'cat'}]