|
<!--Copyright 2023 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
|
|
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
|
|
|
--> |
|
|
|
|
|
# Image captioning |
|
|
|
[[open-in-colab]] |
|
|
|
Image captioning is the task of predicting a caption for a given image. Common real-world applications include aiding visually impaired people in navigating different situations. Image captioning therefore improves content accessibility by describing images to people who cannot see them.
|
|
|
This guide will show you how to: |
|
|
|
* Fine-tune an image captioning model. |
|
* Use the fine-tuned model for inference. |
|
|
|
Before you begin, make sure you have all the necessary libraries installed: |
|
|
|
```bash |
|
pip install transformers datasets evaluate -q |
|
pip install jiwer -q |
|
``` |
|
|
|
We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in: |
|
|
|
|
|
```python |
|
from huggingface_hub import notebook_login |
|
|
|
notebook_login() |
|
``` |
|
|
|
## Load the Pokémon BLIP captions dataset |
|
|
|
Use the 🤗 Datasets library to load a dataset consisting of image-caption pairs. To create your own image captioning dataset
in PyTorch, you can follow [this notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb).
|
|
|
|
|
```python |
|
from datasets import load_dataset |
|
|
|
ds = load_dataset("lambdalabs/pokemon-blip-captions") |
|
ds |
|
``` |
|
```bash |
|
DatasetDict({ |
|
train: Dataset({ |
|
features: ['image', 'text'], |
|
num_rows: 833 |
|
}) |
|
}) |
|
``` |
|
|
|
The dataset has two features, `image` and `text`. |
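
Indexing into the dataset returns a dictionary with a PIL image and its caption string; the values shown in the comments below are illustrative and will differ depending on the sample.

```python
# Look at a single example: `image` is a PIL image, `text` is its caption
example = ds["train"][0]
print(example["image"].size)  # e.g. (1280, 1280) -- actual size varies per image
print(example["text"])        # e.g. "a drawing of a green pokemon" (illustrative)
```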
|
|
|
<Tip> |
|
|
|
Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample a caption amongst the available ones during training. |
|
|
|
</Tip> |
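
The Pokémon dataset used here has a single caption per image, but if your dataset stores several captions per example, you could sample one of them when preparing the data. Below is a minimal sketch, assuming a hypothetical `captions` column that holds a list of strings per image:

```python
import random


# Sketch only: assumes a hypothetical `captions` column with a list of strings per image
def pick_random_caption(example):
    example["text"] = random.choice(example["captions"])
    return example
```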
|
|
|
Split the dataset’s train split into a train and test set with the [`~datasets.Dataset.train_test_split`] method:
|
|
|
|
|
```python |
|
ds = ds["train"].train_test_split(test_size=0.1) |
|
train_ds = ds["train"] |
|
test_ds = ds["test"] |
|
``` |
|
|
|
Let's visualize a couple of samples from the training set. |
|
|
|
|
|
```python |
|
from textwrap import wrap |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
|
|
def plot_images(images, captions): |
|
plt.figure(figsize=(20, 20)) |
|
for i in range(len(images)): |
|
ax = plt.subplot(1, len(images), i + 1) |
|
caption = captions[i] |
|
caption = "\n".join(wrap(caption, 12)) |
|
plt.title(caption) |
|
plt.imshow(images[i]) |
|
plt.axis("off") |
|
|
|
|
|
sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)] |
|
sample_captions = [train_ds[i]["text"] for i in range(5)] |
|
plot_images(sample_images_to_visualize, sample_captions) |
|
``` |
|
|
|
<div class="flex justify-center"> |
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png" alt="Sample training images"/> |
|
</div> |
|
|
|
## Preprocess the dataset |
|
|
|
Since the dataset has two modalities (image and text), the preprocessing pipeline will preprocess both the images and the captions.
|
|
|
To do so, load the processor class associated with the model you are about to fine-tune. |
|
|
|
```python |
|
from transformers import AutoProcessor |
|
|
|
checkpoint = "microsoft/git-base" |
|
processor = AutoProcessor.from_pretrained(checkpoint) |
|
``` |
|
|
|
The processor will internally preprocess the image (which includes resizing and pixel scaling) and tokenize the caption.
|
|
|
```python |
|
def transforms(example_batch): |
|
images = [x for x in example_batch["image"]] |
|
captions = [x for x in example_batch["text"]] |
|
inputs = processor(images=images, text=captions, padding="max_length") |
|
inputs.update({"labels": inputs["input_ids"]}) |
|
return inputs |
|
|
|
|
|
train_ds.set_transform(transforms) |
|
test_ds.set_transform(transforms) |
|
``` |
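
To sanity-check the transform, you can pull a single example and inspect what the processor produced; the exact keys and shapes depend on the checkpoint's processor configuration.

```python
# Fetch one transformed example and check which keys the processor produced
sample = train_ds[0]
print(sample.keys())
# Expect keys like `pixel_values`, `input_ids`, `attention_mask`, and `labels`
```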
|
|
|
With the dataset ready, you can now set up the model for fine-tuning. |
|
|
|
## Load a base model |
|
|
|
Load the GIT base checkpoint (["microsoft/git-base"](https://huggingface.co/microsoft/git-base)) into an [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) object.
|
|
|
|
|
```python |
|
from transformers import AutoModelForCausalLM |
|
|
|
model = AutoModelForCausalLM.from_pretrained(checkpoint) |
|
``` |
|
|
|
## Evaluate |
|
|
|
Image captioning models are typically evaluated with the [Rouge Score](https://huggingface.co/spaces/evaluate-metric/rouge) or [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/wer). For this guide, you will use the Word Error Rate (WER). |
|
|
|
We use the 🤗 Evaluate library to do so. For potential limitations and other gotchas of the WER, refer to [this guide](https://huggingface.co/spaces/evaluate-metric/wer). |
|
|
|
|
|
```python |
|
from evaluate import load |
|
import torch |
|
|
|
wer = load("wer") |
|
|
|
|
|
def compute_metrics(eval_pred): |
|
logits, labels = eval_pred |
|
predicted = logits.argmax(-1) |
|
decoded_labels = processor.batch_decode(labels, skip_special_tokens=True) |
|
decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True) |
|
wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels) |
|
return {"wer_score": wer_score} |
|
``` |
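
To get a feel for what the metric reports, you can run it on a toy prediction/reference pair; the strings below are made up for illustration.

```python
# Toy example: one of the five reference words differs, so WER is 0.2
print(wer.compute(predictions=["a drawing of a pokemon"], references=["a photo of a pokemon"]))
```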
|
|
|
## Train! |
|
|
|
Now, you are ready to start fine-tuning the model. You will use the 🤗 [`Trainer`] for this. |
|
|
|
First, define the training arguments using [`TrainingArguments`]. |
|
|
|
|
|
```python |
|
from transformers import TrainingArguments, Trainer |
|
|
|
model_name = checkpoint.split("/")[1] |
|
|
|
training_args = TrainingArguments( |
|
output_dir=f"{model_name}-pokemon", |
|
learning_rate=5e-5, |
|
num_train_epochs=50, |
|
fp16=True, |
|
per_device_train_batch_size=32, |
|
per_device_eval_batch_size=32, |
|
gradient_accumulation_steps=2, |
|
save_total_limit=3, |
|
evaluation_strategy="steps", |
|
eval_steps=50, |
|
save_strategy="steps", |
|
save_steps=50, |
|
logging_steps=50, |
|
remove_unused_columns=False, |
|
push_to_hub=True, |
|
label_names=["labels"], |
|
load_best_model_at_end=True, |
|
) |
|
``` |
|
|
|
Then pass them along with the datasets and the model to 🤗 Trainer. |
|
|
|
```python |
|
trainer = Trainer( |
|
model=model, |
|
args=training_args, |
|
train_dataset=train_ds, |
|
eval_dataset=test_ds, |
|
compute_metrics=compute_metrics, |
|
) |
|
``` |
|
|
|
To start training, simply call [`~Trainer.train`] on the [`Trainer`] object. |
|
|
|
```python |
|
trainer.train() |
|
``` |
|
|
|
You should see the training loss drop smoothly as training progresses. |
|
|
|
Once training is completed, share your model to the Hub with the [`~Trainer.push_to_hub`] method so everyone can use your model: |
|
|
|
|
|
```python |
|
trainer.push_to_hub() |
|
``` |
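
Since the processor was not passed to the [`Trainer`], it is not uploaded automatically; you may want to push it as well so others can load it alongside your model. The repository name below is an assumption based on the `output_dir` defined above:

```python
# Push the processor to the same repository as the model
# (the repository name is assumed to match the Trainer's output_dir)
processor.push_to_hub(f"{model_name}-pokemon")
```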
|
|
|
## Inference |
|
|
|
Take a sample image from `test_ds` to test the model. |
|
|
|
|
|
```python |
|
from PIL import Image |
|
import requests |
|
|
|
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png" |
|
image = Image.open(requests.get(url, stream=True).raw) |
|
image |
|
``` |
|
|
|
<div class="flex justify-center"> |
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png" alt="Test image"/> |
|
</div> |
|
|
|
Prepare the image for the model.
|
|
|
```python |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
inputs = processor(images=image, return_tensors="pt").to(device) |
|
pixel_values = inputs.pixel_values |
|
``` |
|
|
|
Call [`generate`] and decode the predictions. |
|
|
|
```python |
|
generated_ids = model.generate(pixel_values=pixel_values, max_length=50) |
|
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
print(generated_caption) |
|
``` |
|
```bash |
|
a drawing of a pink and blue pokemon |
|
``` |
|
|
|
Looks like the fine-tuned model generated a pretty good caption! |
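
As an alternative to calling [`generate`] manually, you can load the checkpoint you pushed to the Hub and caption images with the [`pipeline`] API. The repository name below is a placeholder; replace it with your own Hub repository:

```python
from transformers import pipeline

# "your-username/git-base-pokemon" is a placeholder -- use your own Hub repository name
captioner = pipeline("image-to-text", model="your-username/git-base-pokemon")
captioner(url)
```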
|
|