
์˜์ƒ ๋ถ„๋ฅ˜ [[video-classification]]

[[open-in-colab]]

์˜์ƒ ๋ถ„๋ฅ˜๋Š” ์˜์ƒ ์ „์ฒด์— ๋ ˆ์ด๋ธ” ๋˜๋Š” ํด๋ž˜์Šค๋ฅผ ์ง€์ •ํ•˜๋Š” ์ž‘์—…์ž…๋‹ˆ๋‹ค. ๊ฐ ์˜์ƒ์—๋Š” ํ•˜๋‚˜์˜ ํด๋ž˜์Šค๊ฐ€ ์žˆ์„ ๊ฒƒ์œผ๋กœ ์˜ˆ์ƒ๋ฉ๋‹ˆ๋‹ค. ์˜์ƒ ๋ถ„๋ฅ˜ ๋ชจ๋ธ์€ ์˜์ƒ์„ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„ ์–ด๋Š ํด๋ž˜์Šค์— ์†ํ•˜๋Š”์ง€์— ๋Œ€ํ•œ ์˜ˆ์ธก์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ชจ๋ธ์€ ์˜์ƒ์ด ์–ด๋–ค ๋‚ด์šฉ์ธ์ง€ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜์ƒ ๋ถ„๋ฅ˜์˜ ์‹ค์ œ ์‘์šฉ ์˜ˆ๋Š” ํ”ผํŠธ๋‹ˆ์Šค ์•ฑ์—์„œ ์œ ์šฉํ•œ ๋™์ž‘ / ์šด๋™ ์ธ์‹ ์„œ๋น„์Šค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Š” ๋˜ํ•œ ์‹œ๊ฐ ์žฅ์• ์ธ์ด ์ด๋™ํ•  ๋•Œ ๋ณด์กฐํ•˜๋Š”๋ฐ ์‚ฌ์šฉ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๋‹ค์Œ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค:

  1. UCF101 ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํ•˜์œ„ ์ง‘ํ•ฉ์„ ํ†ตํ•ด VideoMAE ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ธฐ.
  2. ๋ฏธ์„ธ ์กฐ์ •ํ•œ ๋ชจ๋ธ์„ ์ถ”๋ก ์— ์‚ฌ์šฉํ•˜๊ธฐ.

์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ ์„ค๋ช…ํ•˜๋Š” ์ž‘์—…์€ ๋‹ค์Œ ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜์—์„œ ์ง€์›๋ฉ๋‹ˆ๋‹ค:

TimeSformer, VideoMAE

Before you begin, make sure you have all the necessary libraries installed:

pip install -q pytorchvideo transformers evaluate

์˜์ƒ์„ ์ฒ˜๋ฆฌํ•˜๊ณ  ์ค€๋น„ํ•˜๊ธฐ ์œ„ํ•ด PyTorchVideo(์ดํ•˜ pytorchvideo)๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:

>>> from huggingface_hub import notebook_login

>>> notebook_login()

UCF101 ๋ฐ์ดํ„ฐ์…‹ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ [[load-ufc101-dataset]]

UCF-101 ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํ•˜์œ„ ์ง‘ํ•ฉ(subset)์„ ๋ถˆ๋Ÿฌ์˜ค๋Š” ๊ฒƒ์œผ๋กœ ์‹œ์ž‘ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ํ•™์Šตํ•˜๋Š”๋ฐ ๋” ๋งŽ์€ ์‹œ๊ฐ„์„ ํ• ์• ํ•˜๊ธฐ ์ „์— ๋ฐ์ดํ„ฐ์˜ ํ•˜์œ„ ์ง‘ํ•ฉ์„ ๋ถˆ๋Ÿฌ์™€ ๋ชจ๋“  ๊ฒƒ์ด ์ž˜ ์ž‘๋™ํ•˜๋Š”์ง€ ์‹คํ—˜ํ•˜๊ณ  ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

>>> from huggingface_hub import hf_hub_download

>>> hf_dataset_identifier = "sayakpaul/ucf101-subset"
>>> filename = "UCF101_subset.tar.gz"
>>> file_path = hf_hub_download(repo_id=hf_dataset_identifier, filename=filename, repo_type="dataset")

๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํ•˜์œ„ ์ง‘ํ•ฉ์ด ๋‹ค์šด๋กœ๋“œ ๋˜๋ฉด, ์••์ถ•๋œ ํŒŒ์ผ์˜ ์••์ถ•์„ ํ•ด์ œํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

>>> import tarfile

>>> with tarfile.open(file_path) as t:
...      t.extractall(".")

์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

UCF101_subset/
    train/
        BandMarching/
            video_1.mp4
            video_2.mp4
            ...
        Archery
            video_1.mp4
            video_2.mp4
            ...
        ...
    val/
        BandMarching/
            video_1.mp4
            video_2.mp4
            ...
        Archery
            video_1.mp4
            video_2.mp4
            ...
        ...
    test/
        BandMarching/
            video_1.mp4
            video_2.mp4
            ...
        Archery
            video_1.mp4
            video_2.mp4
            ...
        ...

์ •๋ ฌ๋œ ์˜์ƒ์˜ ๊ฒฝ๋กœ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

...
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c06.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c02.avi',
'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c06.avi'
...

๋™์ผํ•œ ๊ทธ๋ฃน/์žฅ๋ฉด์— ์†ํ•˜๋Š” ์˜์ƒ ํด๋ฆฝ์€ ํŒŒ์ผ ๊ฒฝ๋กœ์—์„œ g๋กœ ํ‘œ์‹œ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค๋ฉด, v_ApplyEyeMakeup_g07_c04.avi์™€ v_ApplyEyeMakeup_g07_c06.avi ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ๋‘˜์€ ๊ฐ™์€ ๊ทธ๋ฃน์ž…๋‹ˆ๋‹ค.

๊ฒ€์ฆ ๋ฐ ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ๋ถ„ํ• ์„ ํ•  ๋•Œ, ๋ฐ์ดํ„ฐ ๋ˆ„์ถœ(data leakage)์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด ๋™์ผํ•œ ๊ทธ๋ฃน / ์žฅ๋ฉด์˜ ์˜์ƒ ํด๋ฆฝ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ ์‚ฌ์šฉํ•˜๋Š” ํ•˜์œ„ ์ง‘ํ•ฉ์€ ์ด๋Ÿฌํ•œ ์ •๋ณด๋ฅผ ๊ณ ๋ คํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

๊ทธ ๋‹ค์Œ์œผ๋กœ, ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์กด์žฌํ•˜๋Š” ๋ผ๋ฒจ์„ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ, ๋ชจ๋ธ์„ ์ดˆ๊ธฐํ™”ํ•  ๋•Œ ๋„์›€์ด ๋  ๋”•์…”๋„ˆ๋ฆฌ(dictionary data type)๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.

  • label2id: ํด๋ž˜์Šค ์ด๋ฆ„์„ ์ •์ˆ˜์— ๋งคํ•‘ํ•ฉ๋‹ˆ๋‹ค.
  • id2label: ์ •์ˆ˜๋ฅผ ํด๋ž˜์Šค ์ด๋ฆ„์— ๋งคํ•‘ํ•ฉ๋‹ˆ๋‹ค.
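
The snippet below relies on a list of all video file paths in the subset. One way to collect it, reusing the UCF101_subset directory extracted earlier (dataset_root_path is also reused by the pytorchvideo dataset definitions later on):

>>> import pathlib

>>> dataset_root_path = pathlib.Path("UCF101_subset")
>>> all_video_file_paths = (
...     list(dataset_root_path.glob("train/*/*.avi"))
...     + list(dataset_root_path.glob("val/*/*.avi"))
...     + list(dataset_root_path.glob("test/*/*.avi"))
... )
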
>>> class_labels = sorted({str(path).split("/")[2] for path in all_video_file_paths})
>>> label2id = {label: i for i, label in enumerate(class_labels)}
>>> id2label = {i: label for label, i in label2id.items()}

>>> print(f"Unique classes: {list(label2id.keys())}.")

# Unique classes: ['ApplyEyeMakeup', 'ApplyLipstick', 'Archery', 'BabyCrawling', 'BalanceBeam', 'BandMarching', 'BaseballPitch', 'Basketball', 'BasketballDunk', 'BenchPress'].

์ด ๋ฐ์ดํ„ฐ ์„ธํŠธ์—๋Š” ์ด 10๊ฐœ์˜ ๊ณ ์œ ํ•œ ํด๋ž˜์Šค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ ํด๋ž˜์Šค๋งˆ๋‹ค 30๊ฐœ์˜ ์˜์ƒ์ด ํ›ˆ๋ จ ์„ธํŠธ์— ์žˆ์Šต๋‹ˆ๋‹ค

Load a model to fine-tune [[load-a-model-to-fine-tune]]

์‚ฌ์ „ ํ›ˆ๋ จ๋œ ์ฒดํฌํฌ์ธํŠธ์™€ ์ฒดํฌํฌ์ธํŠธ์— ์—ฐ๊ด€๋œ ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์˜์ƒ ๋ถ„๋ฅ˜ ๋ชจ๋ธ์„ ์ธ์Šคํ„ด์Šคํ™”ํ•ฉ๋‹ˆ๋‹ค. ๋ชจ๋ธ์˜ ์ธ์ฝ”๋”์—๋Š” ๋ฏธ๋ฆฌ ํ•™์Šต๋œ ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ์ œ๊ณต๋˜๋ฉฐ, ๋ถ„๋ฅ˜ ํ—ค๋“œ(๋ฐ์ดํ„ฐ๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด)๋Š” ๋ฌด์ž‘์œ„๋กœ ์ดˆ๊ธฐํ™”๋ฉ๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ์„ ์ž‘์„ฑํ•  ๋•Œ๋Š” ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๊ฐ€ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

>>> from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification

>>> model_ckpt = "MCG-NJU/videomae-base"
>>> image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)
>>> model = VideoMAEForVideoClassification.from_pretrained(
...     model_ckpt,
...     label2id=label2id,
...     id2label=id2label,
...     ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
... )

๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜ค๋Š” ๋™์•ˆ, ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒฝ๊ณ ๋ฅผ ๋งˆ์ฃผ์น  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

Some weights of the model checkpoint at MCG-NJU/videomae-base were not used when initializing VideoMAEForVideoClassification: [..., 'decoder.decoder_layers.1.attention.output.dense.bias', 'decoder.decoder_layers.2.attention.attention.key.weight']
- This IS expected if you are initializing VideoMAEForVideoClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VideoMAEForVideoClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The warning is telling us that we are throwing away some weights (e.g. the weights and bias of the classifier layer) and randomly initializing some others (the weights and bias of a new classifier layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.

์ฐธ๊ณ  ์ด ์ฒดํฌํฌ์ธํŠธ๋Š” ๋„๋ฉ”์ธ์ด ๋งŽ์ด ์ค‘์ฒฉ๋œ ์œ ์‚ฌํ•œ ๋‹ค์šด์ŠคํŠธ๋ฆผ ์ž‘์—…์— ๋Œ€ํ•ด ๋ฏธ์„ธ ์กฐ์ •ํ•˜์—ฌ ์–ป์€ ์ฒดํฌํฌ์ธํŠธ์ด๋ฏ€๋กœ ์ด ์ž‘์—…์—์„œ ๋” ๋‚˜์€ ์„ฑ๋Šฅ์„ ๋ณด์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. MCG-NJU/videomae-base-finetuned-kinetics ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ๋ฏธ์„ธ ์กฐ์ •ํ•˜์—ฌ ์–ป์€ ์ฒดํฌํฌ์ธํŠธ๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

ํ›ˆ๋ จ์„ ์œ„ํ•œ ๋ฐ์ดํ„ฐ ์„ธํŠธ ์ค€๋น„ํ•˜๊ธฐ[[prepare-the-datasets-for-training]]

์˜์ƒ ์ „์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•ด PyTorchVideo ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ํ™œ์šฉํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค. ํ•„์š”ํ•œ ์ข…์†์„ฑ์„ ๊ฐ€์ ธ์˜ค๋Š” ๊ฒƒ์œผ๋กœ ์‹œ์ž‘ํ•˜์„ธ์š”.

>>> import os
>>> import pytorchvideo.data

>>> from pytorchvideo.transforms import (
...     ApplyTransformToKey,
...     Normalize,
...     RandomShortSideScale,
...     RemoveKey,
...     ShortSideScale,
...     UniformTemporalSubsample,
... )

>>> from torchvision.transforms import (
...     Compose,
...     Lambda,
...     RandomCrop,
...     RandomHorizontalFlip,
...     Resize,
... )

ํ•™์Šต ๋ฐ์ดํ„ฐ ์„ธํŠธ ๋ณ€ํ™˜์—๋Š” '๊ท ์ผํ•œ ์‹œ๊ฐ„ ์ƒ˜ํ”Œ๋ง(uniform temporal subsampling)', 'ํ”ฝ์…€ ์ •๊ทœํ™”(pixel normalization)', '๋žœ๋ค ์ž˜๋ผ๋‚ด๊ธฐ(random cropping)' ๋ฐ '๋žœ๋ค ์ˆ˜ํ‰ ๋’ค์ง‘๊ธฐ(random horizontal flipping)'์˜ ์กฐํ•ฉ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ๊ฒ€์ฆ ๋ฐ ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ์„ธํŠธ ๋ณ€ํ™˜์—๋Š” '๋žœ๋ค ์ž˜๋ผ๋‚ด๊ธฐ'์™€ '๋žœ๋ค ๋’ค์ง‘๊ธฐ'๋ฅผ ์ œ์™ธํ•œ ๋™์ผํ•œ ๋ณ€ํ™˜ ์ฒด์ธ์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ณ€ํ™˜์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์•Œ์•„๋ณด๋ ค๋ฉด PyTorchVideo ๊ณต์‹ ๋ฌธ์„œ๋ฅผ ํ™•์ธํ•˜์„ธ์š”.

์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ๊ณผ ๊ด€๋ จ๋œ ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‹ค์Œ ์ •๋ณด๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

  • ์˜์ƒ ํ”„๋ ˆ์ž„ ํ”ฝ์…€์„ ์ •๊ทœํ™”ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” ์ด๋ฏธ์ง€ ํ‰๊ท ๊ณผ ํ‘œ์ค€ ํŽธ์ฐจ
  • ์˜์ƒ ํ”„๋ ˆ์ž„์ด ์กฐ์ •๋  ๊ณต๊ฐ„ ํ•ด์ƒ๋„

Start by defining some constants.

>>> mean = image_processor.image_mean
>>> std = image_processor.image_std
>>> if "shortest_edge" in image_processor.size:
...     height = width = image_processor.size["shortest_edge"]
... else:
...     height = image_processor.size["height"]
...     width = image_processor.size["width"]
>>> resize_to = (height, width)

>>> num_frames_to_sample = model.config.num_frames
>>> sample_rate = 4
>>> fps = 30
>>> clip_duration = num_frames_to_sample * sample_rate / fps
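
With the base VideoMAE checkpoint used here, model.config.num_frames is 16, so clip_duration works out to 16 * 4 / 30 ≈ 2.13 seconds of video per sampled clip.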

์ด์ œ ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ํŠนํ™”๋œ ์ „์ฒ˜๋ฆฌ(transform)๊ณผ ๋ฐ์ดํ„ฐ ์„ธํŠธ ์ž์ฒด๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ๋จผ์ € ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ ์„ธํŠธ๋กœ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค:

>>> train_transform = Compose(
...     [
...         ApplyTransformToKey(
...             key="video",
...             transform=Compose(
...                 [
...                     UniformTemporalSubsample(num_frames_to_sample),
...                     Lambda(lambda x: x / 255.0),
...                     Normalize(mean, std),
...                     RandomShortSideScale(min_size=256, max_size=320),
...                     RandomCrop(resize_to),
...                     RandomHorizontalFlip(p=0.5),
...                 ]
...             ),
...         ),
...     ]
... )

>>> train_dataset = pytorchvideo.data.Ucf101(
...     data_path=os.path.join(dataset_root_path, "train"),
...     clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
...     decode_audio=False,
...     transform=train_transform,
... )

๊ฐ™์€ ๋ฐฉ์‹์˜ ์ž‘์—… ํ๋ฆ„์„ ๊ฒ€์ฆ๊ณผ ํ‰๊ฐ€ ์„ธํŠธ์—๋„ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

>>> val_transform = Compose(
...     [
...         ApplyTransformToKey(
...             key="video",
...             transform=Compose(
...                 [
...                     UniformTemporalSubsample(num_frames_to_sample),
...                     Lambda(lambda x: x / 255.0),
...                     Normalize(mean, std),
...                     Resize(resize_to),
...                 ]
...             ),
...         ),
...     ]
... )

>>> val_dataset = pytorchvideo.data.Ucf101(
...     data_path=os.path.join(dataset_root_path, "val"),
...     clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
...     decode_audio=False,
...     transform=val_transform,
... )

>>> test_dataset = pytorchvideo.data.Ucf101(
...     data_path=os.path.join(dataset_root_path, "test"),
...     clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
...     decode_audio=False,
...     transform=val_transform,
... )

์ฐธ๊ณ : ์œ„์˜ ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํŒŒ์ดํ”„๋ผ์ธ์€ ๊ณต์‹ ํŒŒ์ดํ† ์น˜ ์˜ˆ์ œ์—์„œ ๊ฐ€์ ธ์˜จ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์šฐ๋ฆฌ๋Š” UCF-101 ๋ฐ์ดํ„ฐ์…‹์— ๋งž๊ฒŒ pytorchvideo.data.Ucf101() ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๋‚ด๋ถ€์ ์œผ๋กœ ์ด ํ•จ์ˆ˜๋Š” pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset ๊ฐ์ฒด๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. LabeledVideoDataset ํด๋ž˜์Šค๋Š” PyTorchVideo ๋ฐ์ดํ„ฐ์…‹์—์„œ ๋ชจ๋“  ์˜์ƒ ๊ด€๋ จ ์ž‘์—…์˜ ๊ธฐ๋ณธ ํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ PyTorchVideo์—์„œ ๋ฏธ๋ฆฌ ์ œ๊ณตํ•˜์ง€ ์•Š๋Š” ์‚ฌ์šฉ์ž ์ง€์ • ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋ฉด, ์ด ํด๋ž˜์Šค๋ฅผ ์ ์ ˆํ•˜๊ฒŒ ํ™•์žฅํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค. ๋” ์ž์„ธํ•œ ์‚ฌํ•ญ์ด ์•Œ๊ณ  ์‹ถ๋‹ค๋ฉด data API ๋ฌธ์„œ ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”. ๋˜ํ•œ ์œ„์˜ ์˜ˆ์‹œ์™€ ์œ ์‚ฌํ•œ ๊ตฌ์กฐ๋ฅผ ๊ฐ–๋Š” ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ๋‹ค๋ฉด, pytorchvideo.data.Ucf101() ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ฐ ๋ฌธ์ œ๊ฐ€ ์—†์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์˜์ƒ์˜ ๊ฐœ์ˆ˜๋ฅผ ์•Œ๊ธฐ ์œ„ํ•ด num_videos ์ธ์ˆ˜์— ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

>>> print(train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos)
# (300, 30, 75)

Visualize the preprocessed video for better debugging [[visualize-the-preprocessed-video-for-better-debugging]]

>>> import imageio
>>> import numpy as np
>>> from IPython.display import Image

>>> def unnormalize_img(img):
...     """Un-normalizes the image pixels."""
...     img = (img * std) + mean
...     img = (img * 255).astype("uint8")
...     return img.clip(0, 255)

>>> def create_gif(video_tensor, filename="sample.gif"):
...     """Prepares a GIF from a video tensor.
...     
...     The video tensor is expected to have the following shape:
...     (num_frames, num_channels, height, width).
...     """
...     frames = []
...     for video_frame in video_tensor:
...         frame_unnormalized = unnormalize_img(video_frame.permute(1, 2, 0).numpy())
...         frames.append(frame_unnormalized)
...     kargs = {"duration": 0.25}
...     imageio.mimsave(filename, frames, "GIF", **kargs)
...     return filename

>>> def display_gif(video_tensor, gif_name="sample.gif"):
...     """Prepares and displays a GIF from a video tensor."""
...     video_tensor = video_tensor.permute(1, 0, 2, 3)
...     gif_filename = create_gif(video_tensor, gif_name)
...     return Image(filename=gif_filename)

>>> sample_video = next(iter(train_dataset))
>>> video_tensor = sample_video["video"]
>>> display_gif(video_tensor)
[Sample GIF of a preprocessed clip: a person playing basketball]

๋ชจ๋ธ ํ›ˆ๋ จํ•˜๊ธฐ[[train-the-model]]

๐Ÿค— Transformers์˜ Trainer๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ํ›ˆ๋ จ์‹œ์ผœ๋ณด์„ธ์š”. Trainer๋ฅผ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๋ ค๋ฉด ํ›ˆ๋ จ ์„ค์ •๊ณผ ํ‰๊ฐ€ ์ง€ํ‘œ๋ฅผ ์ •์˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๊ฒƒ์€ TrainingArguments์ž…๋‹ˆ๋‹ค. ์ด ํด๋ž˜์Šค๋Š” ํ›ˆ๋ จ์„ ๊ตฌ์„ฑํ•˜๋Š” ๋ชจ๋“  ์†์„ฑ์„ ํฌํ•จํ•˜๋ฉฐ, ํ›ˆ๋ จ ์ค‘ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•  ์ถœ๋ ฅ ํด๋” ์ด๋ฆ„์„ ํ•„์š”๋กœ ํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ ๐Ÿค— Hub์˜ ๋ชจ๋ธ ์ €์žฅ์†Œ์˜ ๋ชจ๋“  ์ •๋ณด๋ฅผ ๋™๊ธฐํ™”ํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค.

๋Œ€๋ถ€๋ถ„์˜ ํ›ˆ๋ จ ์ธ์ˆ˜๋Š” ๋”ฐ๋กœ ์„ค๋ช…ํ•  ํ•„์š”๋Š” ์—†์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์—ฌ๊ธฐ์—์„œ ์ค‘์š”ํ•œ ์ธ์ˆ˜๋Š” remove_unused_columns=False ์ž…๋‹ˆ๋‹ค. ์ด ์ธ์ž๋Š” ๋ชจ๋ธ์˜ ํ˜ธ์ถœ ํ•จ์ˆ˜์—์„œ ์‚ฌ์šฉ๋˜์ง€ ์•Š๋Š” ๋ชจ๋“  ์†์„ฑ ์—ด(columns)์„ ์‚ญ์ œํ•ฉ๋‹ˆ๋‹ค. ๊ธฐ๋ณธ๊ฐ’์€ ์ผ๋ฐ˜์ ์œผ๋กœ True์ž…๋‹ˆ๋‹ค. ์ด๋Š” ์‚ฌ์šฉ๋˜์ง€ ์•Š๋Š” ๊ธฐ๋Šฅ ์—ด์„ ์‚ญ์ œํ•˜๋Š” ๊ฒƒ์ด ์ด์ƒ์ ์ด๋ฉฐ, ์ž…๋ ฅ์„ ๋ชจ๋ธ์˜ ํ˜ธ์ถœ ํ•จ์ˆ˜๋กœ ํ’€๊ธฐ(unpack)๊ฐ€ ์‰ฌ์›Œ์ง€๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ด ๊ฒฝ์šฐ์—๋Š” pixel_values(๋ชจ๋ธ์˜ ์ž…๋ ฅ์œผ๋กœ ํ•„์ˆ˜์ ์ธ ํ‚ค)๋ฅผ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋˜์ง€ ์•Š๋Š” ๊ธฐ๋Šฅ('video'๊ฐ€ ํŠนํžˆ ๊ทธ๋ ‡์Šต๋‹ˆ๋‹ค)์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ remove_unused_columns์„ False๋กœ ์„ค์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

>>> from transformers import TrainingArguments, Trainer

>>> model_name = model_ckpt.split("/")[-1]
>>> new_model_name = f"{model_name}-finetuned-ucf101-subset"
>>> num_epochs = 4
>>> batch_size = 4  # example value, not from the original guide; adjust to your hardware (used by TrainingArguments below)

>>> args = TrainingArguments(
...     new_model_name,
...     remove_unused_columns=False,
...     evaluation_strategy="epoch",
...     save_strategy="epoch",
...     learning_rate=5e-5,
...     per_device_train_batch_size=batch_size,
...     per_device_eval_batch_size=batch_size,
...     warmup_ratio=0.1,
...     logging_steps=10,
...     load_best_model_at_end=True,
...     metric_for_best_model="accuracy",
...     push_to_hub=True,
...     max_steps=(train_dataset.num_videos // batch_size) * num_epochs,
... )

The dataset returned by pytorchvideo.data.Ucf101() doesn't implement the __len__ method. As such, we must define max_steps when instantiating TrainingArguments.

๋‹ค์Œ์œผ๋กœ, ํ‰๊ฐ€์ง€ํ‘œ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๊ณ , ์˜ˆ์ธก๊ฐ’์—์„œ ํ‰๊ฐ€์ง€ํ‘œ๋ฅผ ๊ณ„์‚ฐํ•  ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ํ•„์š”ํ•œ ์ „์ฒ˜๋ฆฌ ์ž‘์—…์€ ์˜ˆ์ธก๋œ ๋กœ์ง“(logits)์— argmax ๊ฐ’์„ ์ทจํ•˜๋Š” ๊ฒƒ๋ฟ์ž…๋‹ˆ๋‹ค:

import evaluate

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

ํ‰๊ฐ€์— ๋Œ€ํ•œ ์ฐธ๊ณ ์‚ฌํ•ญ:

VideoMAE ๋…ผ๋ฌธ์—์„œ ์ €์ž๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ํ‰๊ฐ€ ์ „๋žต์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ํ…Œ์ŠคํŠธ ์˜์ƒ์—์„œ ์—ฌ๋Ÿฌ ํด๋ฆฝ์„ ์„ ํƒํ•˜๊ณ  ๊ทธ ํด๋ฆฝ์— ๋‹ค์–‘ํ•œ ํฌ๋กญ์„ ์ ์šฉํ•˜์—ฌ ์ง‘๊ณ„ ์ ์ˆ˜๋ฅผ ๋ณด๊ณ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ด๋ฒˆ ํŠœํ† ๋ฆฌ์–ผ์—์„œ๋Š” ๊ฐ„๋‹จํ•จ๊ณผ ๊ฐ„๊ฒฐํ•จ์„ ์œ„ํ•ด ํ•ด๋‹น ์ „๋žต์„ ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
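
If you do want to approximate that protocol, one rough way is to run several clips sampled from the same test video through the model and average their scores. The helper below is a hypothetical sketch; it assumes each clip tensor is already permuted to (num_frames, num_channels, height, width), as in the collate_fn defined next:

>>> import torch

>>> def aggregate_clip_scores(model, clips):
...     """Average the softmax scores over several clips sampled from a single video."""
...     scores = []
...     with torch.no_grad():
...         for clip in clips:  # each clip: (num_frames, num_channels, height, width)
...             logits = model(pixel_values=clip.unsqueeze(0)).logits
...             scores.append(logits.softmax(dim=-1))
...     return torch.stack(scores).mean(dim=0)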

Also, define a collate_fn, which will be used to batch examples together. Each batch consists of 2 keys, namely pixel_values and labels.

>>> import torch

>>> def collate_fn(examples):
...     # permute to (num_frames, num_channels, height, width)
...     pixel_values = torch.stack(
...         [example["video"].permute(1, 0, 2, 3) for example in examples]
...     )
...     labels = torch.tensor([example["label"] for example in examples])
...     return {"pixel_values": pixel_values, "labels": labels}

๊ทธ๋Ÿฐ ๋‹ค์Œ ์ด ๋ชจ๋“  ๊ฒƒ์„ ๋ฐ์ดํ„ฐ ์„ธํŠธ์™€ ํ•จ๊ป˜ Trainer์— ์ „๋‹ฌํ•˜๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค:

>>> trainer = Trainer(
...     model,
...     args,
...     train_dataset=train_dataset,
...     eval_dataset=val_dataset,
...     tokenizer=image_processor,
...     compute_metrics=compute_metrics,
...     data_collator=collate_fn,
... )

๋ฐ์ดํ„ฐ๋ฅผ ์ด๋ฏธ ์ฒ˜๋ฆฌํ–ˆ๋Š”๋ฐ๋„ ๋ถˆ๊ตฌํ•˜๊ณ  image_processor๋ฅผ ํ† ํฌ๋‚˜์ด์ € ์ธ์ˆ˜๋กœ ๋„ฃ์€ ์ด์œ ๋Š” JSON์œผ๋กœ ์ €์žฅ๋˜๋Š” ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ ๊ตฌ์„ฑ ํŒŒ์ผ์ด Hub์˜ ์ €์žฅ์†Œ์— ์—…๋กœ๋“œ๋˜๋„๋ก ํ•˜๊ธฐ ์œ„ํ•จ์ž…๋‹ˆ๋‹ค.

Fine-tune the model by calling the train method:

>>> train_results = trainer.train()

ํ•™์Šต์ด ์™„๋ฃŒ๋˜๋ฉด, ๋ชจ๋ธ์„ [~transformers.Trainer.push_to_hub] ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ—ˆ๋ธŒ์— ๊ณต์œ ํ•˜์—ฌ ๋ˆ„๊ตฌ๋‚˜ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•ฉ๋‹ˆ๋‹ค:

>>> trainer.push_to_hub()

Inference [[inference]]

Great, now that you have fine-tuned a model, you can use it for inference.

์ถ”๋ก ์— ์‚ฌ์šฉํ•  ์˜์ƒ์„ ๋ถˆ๋Ÿฌ์˜ค์„ธ์š”:

>>> sample_test_video = next(iter(test_dataset))
[Sample GIF of a test clip: teams playing basketball]

๋ฏธ์„ธ ์กฐ์ •๋œ ๋ชจ๋ธ์„ ์ถ”๋ก ์— ์‚ฌ์šฉํ•˜๋Š” ๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•์€ pipeline์—์„œ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋ชจ๋ธ๋กœ ์˜์ƒ ๋ถ„๋ฅ˜๋ฅผ ํ•˜๊ธฐ ์œ„ํ•ด pipeline์„ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๊ณ  ์˜์ƒ์„ ์ „๋‹ฌํ•˜์„ธ์š”:

>>> from transformers import pipeline

>>> video_cls = pipeline(model="my_awesome_video_cls_model")
>>> video_cls("https://huggingface.co/datasets/sayakpaul/ucf101-subset/resolve/main/v_BasketballDunk_g14_c06.avi")
[{'score': 0.9272987842559814, 'label': 'BasketballDunk'},
 {'score': 0.017777055501937866, 'label': 'BabyCrawling'},
 {'score': 0.01663011871278286, 'label': 'BalanceBeam'},
 {'score': 0.009560945443809032, 'label': 'BandMarching'},
 {'score': 0.0068979403004050255, 'label': 'BaseballPitch'}]

๋งŒ์•ฝ ์›ํ•œ๋‹ค๋ฉด ์ˆ˜๋™์œผ๋กœ pipeline์˜ ๊ฒฐ๊ณผ๋ฅผ ์žฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

>>> def run_inference(model, video):
...     # (num_frames, num_channels, height, width)
...     permuted_sample_test_video = video.permute(1, 0, 2, 3)
...     inputs = {
...         "pixel_values": permuted_sample_test_video.unsqueeze(0),
...         "labels": torch.tensor(
...             [sample_test_video["label"]]
...         ),  # this can be skipped if you don't have labels available.
...     }

...     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...     inputs = {k: v.to(device) for k, v in inputs.items()}
...     model = model.to(device)

...     # forward pass
...     with torch.no_grad():
...         outputs = model(**inputs)
...         logits = outputs.logits

...     return logits

๋ชจ๋ธ์— ์ž…๋ ฅ๊ฐ’์„ ๋„ฃ๊ณ  logits์„ ๋ฐ˜ํ™˜๋ฐ›์œผ์„ธ์š”:

>>> logits = run_inference(trained_model, sample_test_video["video"])

logits์„ ๋””์ฝ”๋”ฉํ•˜๋ฉด, ์šฐ๋ฆฌ๋Š” ๋‹ค์Œ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
# Predicted class: BasketballDunk