# Gaze-LLE

<div style="text-align:center;">
<img src="./assets/the_office.png" height="100"/>
<img src="./assets/MLB_1.gif" height="100"/>
<img src="./assets/succession.png" height="100"/>
<img src="./assets/CBS_2.gif" height="100"/>
</div>

[Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders](https://arxiv.org/abs/2412.09586) \
[Fiona Ryan](https://fkryan.github.io/), Ajay Bati, [Sangmin Lee](https://sites.google.com/view/sangmin-lee), [Daniel Bolya](https://dbolya.github.io/), [Judy Hoffman](https://faculty.cc.gatech.edu/~judy/)\*, [James M. Rehg](https://rehg.org/)\*

This is the official implementation for Gaze-LLE, a transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!

<div style="text-align:center;">
<img src="./assets/gazelle_arch.png" height="200"/>
</div>

## Installation

Clone this repo, then create the virtual environment.

```
conda env create -f environment.yml
conda activate gazelle
pip install -e .
```

If your system supports it, consider installing [xformers](https://github.com/facebookresearch/xformers) to speed up attention computation. The command below uses the CUDA 11.8 (`cu118`) package index; adjust the index URL to match your CUDA version.

```
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
```
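
To confirm that the optional install succeeded, the following minimal sketch uses only the standard library to check whether `xformers` can be found on the import path. The helper name `xformers_available` is hypothetical (it is not part of the gazelle package), and the check only locates the package; it does not verify that its CUDA build matches your GPU or PyTorch version.

```python
import importlib.util


def xformers_available() -> bool:
    """Return True if the xformers package can be found on the import path.

    importlib.util.find_spec locates the package without importing it,
    so this check does not load any CUDA kernels.
    """
    return importlib.util.find_spec("xformers") is not None


if __name__ == "__main__":
    if xformers_available():
        print("xformers is installed")
    else:
        print("xformers is not installed")
```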

## Pretrained Models

We provide the following pretrained models for download.

| Name | Backbone type | Backbone name | Training data | Checkpoint |
| ---- | ------------- | ------------- | ------------- | ---------- |
| ```gazelle_dinov2_vitb14``` | DINOv2 ViT-B | ```dinov2_vitb14``` | GazeFollow | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14.pt) |
| ```gazelle_dinov2_vitl14``` | DINOv2 ViT-L | ```dinov2_vitl14``` | GazeFollow | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14.pt) |
| ```gazelle_dinov2_vitb14_inout``` | DINOv2 ViT-B | ```dinov2_vitb14``` | GazeFollow -> VideoAttentionTarget | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14_inout.pt) |
| ```gazelle_dinov2_vitl14_inout``` | DINOv2 ViT-L | ```dinov2_vitl14``` | GazeFollow -> VideoAttentionTarget | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14_inout.pt) |

Note that our Gaze-LLE checkpoints contain only the gaze decoder weights; the DINOv2 backbone weights are downloaded from ```facebookresearch/dinov2``` on PyTorch Hub when the Gaze-LLE model is created in our code.
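
Every download link in the table follows the same release pattern, so a small helper can build a checkpoint URL from a model name. The sketch below is hypothetical: `checkpoint_url`, `CHECKPOINT_NAMES`, and `RELEASE_BASE` are illustrative names, not part of the gazelle package, and the base URL and file names are copied from the table's links.

```python
# Hypothetical helper (not part of the gazelle package) that builds the
# release URL for a checkpoint listed in the Pretrained Models table.
RELEASE_BASE = "https://github.com/fkryan/gazelle/releases/download/v1.0.0"

CHECKPOINT_NAMES = (
    "gazelle_dinov2_vitb14",
    "gazelle_dinov2_vitl14",
    "gazelle_dinov2_vitb14_inout",
    "gazelle_dinov2_vitl14_inout",
)


def checkpoint_url(model_name: str) -> str:
    """Return the download URL for one of the released decoder checkpoints."""
    if model_name not in CHECKPOINT_NAMES:
        raise ValueError(
            f"unknown model name {model_name!r}; expected one of {CHECKPOINT_NAMES}"
        )
    return f"{RELEASE_BASE}/{model_name}.pt"
```

As a sketch, assuming the downloaded file is the same state dict that `torch.load` returns for a local checkpoint, the URL could be passed to `torch.hub.load_state_dict_from_url(url, map_location="cpu")` and the result given to `model.load_gazelle_state_dict(...)` instead of downloading the file by hand.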

The GazeFollow-trained models output a spatial heatmap of gaze locations over the scene with values in the range ```[0,1]```, where 1 represents the highest probability of the location being a gaze target. The models that are additionally finetuned on VideoAttentionTarget also predict an in/out-of-frame gaze score in the range ```[0,1]```, where 1 means the person's gaze target is in the frame.
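
To turn these outputs into a single gaze point and an in/out decision, one simple option is to take the location of the heatmap's maximum and threshold the in/out score. The sketch below works on plain nested lists (for a `[64, 64]` tensor, pass `heatmap.tolist()`). The helper names are hypothetical, and both the cell-center mapping and the 0.5 threshold are illustrative assumptions, not necessarily what the evaluation scripts do.

```python
from typing import Optional, Sequence, Tuple


def heatmap_to_point(heatmap: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Return the (x, y) location of the heatmap maximum in [0, 1] image coordinates.

    Assumption: cell (row, col) of an H x W heatmap is mapped to its center,
    ((col + 0.5) / W, (row + 0.5) / H). Ties keep the first maximum found.
    """
    height = len(heatmap)
    width = len(heatmap[0])
    best_row, best_col, best_val = 0, 0, float("-inf")
    for r, row in enumerate(heatmap):
        for c, val in enumerate(row):
            if val > best_val:
                best_row, best_col, best_val = r, c, val
    return ((best_col + 0.5) / width, (best_row + 0.5) / height)


def is_in_frame(inout_score: Optional[float], threshold: float = 0.5) -> Optional[bool]:
    """Threshold an in/out score (1 = in frame); None means the model has no in/out head."""
    if inout_score is None:
        return None
    return inout_score >= threshold
```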

### PyTorch Hub

The models are also available on PyTorch Hub for easy use without installing from source.

```
import torch

# Load one of the following; each call returns the model and its input transform
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14')
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14')
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14_inout')
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
```

## Usage

### Colab Demo Notebook

Check out our [Demo Notebook](https://colab.research.google.com/drive/1TSoyFvNs1-au9kjOZN_fo5ebdzngSPDq?usp=sharing) on Google Colab to see how to detect gaze for all people in an image.

### Gaze Prediction

Gaze-LLE is set up for multi-person inference: for a single image, Gaze-LLE encodes the scene only once and then uses those features to predict the gaze of multiple people in the image. The input is a batch of image tensors and, for each image, a list of bounding boxes for the heads of the people whose gaze should be predicted. The bounding boxes are tuples of the form ```(xmin, ymin, xmax, ymax)``` in ```[0,1]``` normalized image coordinates. Below we show how to perform inference for a single person in a single image.

```
from PIL import Image
import torch
from gazelle.model import get_gazelle_model

model, transform = get_gazelle_model("gazelle_dinov2_vitl14_inout")
model.load_gazelle_state_dict(torch.load("/path/to/checkpoint.pt", map_location="cpu", weights_only=True))
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

image = Image.open("path/to/image.png").convert("RGB")
input = {
    "images": transform(image).unsqueeze(dim=0).to(device),  # tensor of shape [1, 3, 448, 448]
    "bboxes": [[(0.1, 0.2, 0.5, 0.7)]]  # list of lists of bbox tuples
}

with torch.no_grad():
    output = model(input)
predicted_heatmap = output["heatmap"][0][0]  # prediction for first person in first image; tensor of size [64, 64]
# in/out of frame score (1 = in frame); output["inout"] is None for non-inout models
predicted_inout = output["inout"][0][0] if output["inout"] is not None else None
```
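
Because each image can carry several head boxes, the example above indexes the outputs as `output["heatmap"][i][j]` for the j-th box of image i. The hypothetical helper below walks that nesting and pairs each heatmap with its in/out score, or `None` for models without an in/out head. It assumes `output["inout"]` is indexed the same way as `output["heatmap"]`, and it relies only on iteration and indexing, so it works on lists as well as tensors.

```python
from typing import Any, Iterator, Optional, Tuple


def iter_predictions(output: dict) -> Iterator[Tuple[int, int, Any, Optional[Any]]]:
    """Yield (image_index, person_index, heatmap, inout) for every queried head.

    Assumes output["heatmap"][i][j] is the heatmap for the j-th bbox of image i
    and output["inout"] is either None or nested the same way.
    """
    heatmaps = output["heatmap"]
    inouts = output.get("inout")
    for i, per_image in enumerate(heatmaps):
        for j, heatmap in enumerate(per_image):
            inout = None if inouts is None else inouts[i][j]
            yield i, j, heatmap, inout
```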

We empirically find that Gaze-LLE is effective without a bounding box input for scenes with just one person. However, providing a bounding box can improve results, and it is necessary in scenes with multiple people to specify whose gaze to estimate. To run inference without a bounding box, use ```None``` in place of a bounding box tuple in the bbox list (e.g. ```input["bboxes"] = [[None]]``` in the example above).
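
Head boxes often come from a detector in pixel coordinates. The minimal sketch below converts them to the `[0,1]` normalized `(xmin, ymin, xmax, ymax)` tuples described above and assembles the per-image `bboxes` list, passing `None` through for box-free inference. The helper names are hypothetical, and clamping boxes that extend past the image border is an assumption of this sketch rather than a documented requirement of the model.

```python
from typing import List, Optional, Sequence, Tuple

BBox = Tuple[float, float, float, float]


def _clamp01(v: float) -> float:
    """Clamp a coordinate to the [0, 1] range."""
    return min(max(v, 0.0), 1.0)


def normalize_bbox(bbox_px: Sequence[float], width: int, height: int) -> BBox:
    """Convert a pixel-space (xmin, ymin, xmax, ymax) head box to [0, 1] coordinates."""
    xmin, ymin, xmax, ymax = bbox_px
    return (
        _clamp01(xmin / width),
        _clamp01(ymin / height),
        _clamp01(xmax / width),
        _clamp01(ymax / height),
    )


def build_bboxes(
    boxes_per_image: Sequence[Sequence[Optional[Sequence[float]]]],
    sizes: Sequence[Tuple[int, int]],
) -> List[List[Optional[BBox]]]:
    """Build the nested list for input["bboxes"]: one inner list per image.

    sizes[i] is the (width, height) of image i, the same order as PIL's Image.size.
    A None entry is kept as None (inference without a bounding box).
    """
    return [
        [None if box is None else normalize_bbox(box, w, h) for box in boxes]
        for boxes, (w, h) in zip(boxes_per_image, sizes)
    ]
```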

We also provide a function to visualize the predicted heatmap for an image.

```
import matplotlib.pyplot as plt
from gazelle.utils import visualize_heatmap

viz = visualize_heatmap(image, predicted_heatmap)
plt.imshow(viz)
plt.show()
```

## Evaluate

We provide evaluation scripts for GazeFollow and VideoAttentionTarget below to reproduce our results with the released checkpoints.

### GazeFollow

Download the GazeFollow dataset [here](https://github.com/ejcgt/attention-target-detection?tab=readme-ov-file#dataset). We provide a preprocessing script ```data_prep/preprocess_gazefollow.py```, which preprocesses and compiles the annotations into a JSON file for each split within the dataset folder. Run the preprocessing script as follows:

```
python data_prep/preprocess_gazefollow.py --data_path /path/to/gazefollow/data_new
```

Download the pretrained model checkpoints above and use ```--model_name``` and ```--ckpt_path``` to specify the model type and checkpoint for evaluation.

```
python scripts/eval_gazefollow.py \
    --data_path /path/to/gazefollow/data_new \
    --model_name gazelle_dinov2_vitl14 \
    --ckpt_path /path/to/checkpoint.pt \
    --batch_size 128
```
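
GazeFollow test images carry gaze annotations from several annotators, and the distance metrics commonly reported on GazeFollow compare a predicted point with them in normalized coordinates: the L2 distance to the mean annotation ("average L2") and the L2 distance to the closest single annotation ("minimum L2"). The pure-Python sketch below computes both for one prediction. It is illustrative only and not taken from `scripts/eval_gazefollow.py`, which may differ in details such as filtering invalid annotations, and it does not cover the AUC metric.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def l2(p: Point, q: Point) -> float:
    """Euclidean distance between two (x, y) points."""
    return math.hypot(p[0] - q[0], p[1] - q[1])


def gazefollow_distances(pred: Point, annotations: Sequence[Point]) -> Tuple[float, float]:
    """Return (avg_l2, min_l2) for one prediction.

    avg_l2: distance from pred to the mean of all annotated gaze points.
    min_l2: distance from pred to the closest annotated gaze point.
    All points are (x, y) in [0, 1] normalized image coordinates.
    """
    if not annotations:
        raise ValueError("need at least one annotation")
    mean_x = sum(x for x, _ in annotations) / len(annotations)
    mean_y = sum(y for _, y in annotations) / len(annotations)
    avg_l2 = l2(pred, (mean_x, mean_y))
    min_l2 = min(l2(pred, a) for a in annotations)
    return avg_l2, min_l2
```

Dataset-level numbers would then be the mean of each quantity over all test samples.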

### VideoAttentionTarget

Download the VideoAttentionTarget dataset [here](https://github.com/ejcgt/attention-target-detection?tab=readme-ov-file#dataset-1). We provide a preprocessing script ```data_prep/preprocess_vat.py```, which preprocesses and compiles the annotations into a JSON file for each split within the dataset folder. Run the preprocessing script as follows:

```
python data_prep/preprocess_vat.py --data_path /path/to/videoattentiontarget
```

Download the pretrained model checkpoints above and use ```--model_name``` and ```--ckpt_path``` to specify the model type and checkpoint for evaluation.

```
python scripts/eval_vat.py \
    --data_path /path/to/videoattentiontarget \
    --model_name gazelle_dinov2_vitl14_inout \
    --ckpt_path /path/to/checkpoint.pt \
    --batch_size 64
```
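
The in/out checkpoints are also judged on how well the in/out-of-frame score separates in-frame from out-of-frame targets, and average precision (AP) over these scores is the metric commonly reported for this on VideoAttentionTarget. The pure-Python sketch below computes non-interpolated AP, treating in-frame (label 1) as the positive class to match the score convention above. It is illustrative only; `scripts/eval_vat.py` may use a library routine such as scikit-learn's `average_precision_score` instead, which can differ from this sketch when scores are tied.

```python
from typing import Sequence


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Non-interpolated average precision for binary labels (1 = positive).

    Predictions are ranked by descending score; the precision at the rank of
    each positive example is averaged over all positives. Python's sort is
    stable, so tied scores keep their input order.
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    num_pos = sum(1 for y in labels if y == 1)
    if num_pos == 0:
        raise ValueError("average precision is undefined without positive labels")
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, idx in enumerate(order, start=1):
        if labels[idx] == 1:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / num_pos
```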

## Citation

```
@article{ryan2024gazelle,
  author = {Ryan, Fiona and Bati, Ajay and Lee, Sangmin and Bolya, Daniel and Hoffman, Judy and Rehg, James M},
  title = {Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders},
  journal = {arXiv preprint arXiv:2412.09586},
  year = {2024},
}
```

## References

- Our models are built on top of pretrained DINOv2 models from PyTorch Hub ([GitHub repo](https://github.com/facebookresearch/dinov2)).
- Our GazeFollow and VideoAttentionTarget preprocessing code is based on [Detecting Attended Targets in Video](https://github.com/ejcgt/attention-target-detection).
- We use [PyTorch Image Models (timm)](https://github.com/huggingface/pytorch-image-models) for our transformer implementation.
- We use [xFormers](https://github.com/facebookresearch/xformers) for efficient multi-head attention.