metadata
language:
  - vi
pretty_name: Well-known Vietnamese people and corresponding abstracts in Wikipedia
source_datasets:
  - original
size_categories:
  - 1K<n<10K
tags:
  - wikipedia
  - images
  - text
  - LM
dataset_info:
  features:
    - name: image
      dtype: image
    - name: title
      dtype: string
    - name: text
      dtype: string
license: mit
datasets:
  - Seeker38/vietnamese_face_wiki
metrics:
  - bleu

Image Captioning - Fine-Tuned ViT-PhoBERT Model

This is a ViT-PhoBERT model (a VisionEncoderDecoderModel pairing a ViT image encoder with a PhoBERT text decoder) fine-tuned on the vietnamese_face_wiki dataset.

How to use

Import the required libraries

import random  # used by generate_caption to sample images

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel

Load the dataset you need (the example below uses the train split of Seeker38/augmented_vi_face_wiki)

from datasets import load_dataset

dataset = load_dataset("Seeker38/augmented_vi_face_wiki", split="train")

Load the fine-tuned model and the PhoBERT tokenizer, and move the model to the available device

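from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("Seeker38/ViT_PhoBert_face_vi_wiki").to(device)
phobert_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2", add_special_tokens=True)

# Make sure the tokenizer has a padding token
if phobert_tokenizer.pad_token is None:
    phobert_tokenizer.add_special_tokens({'pad_token': '[PAD]'})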
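PhoBERT was pre-trained on word-segmented Vietnamese text, where the syllables of a multi-syllable word are joined with underscores (for example `Việt_Nam`). If the captions this model produces contain such underscores (this depends on how the training captions were prepared, which this card does not state), a small post-processing step can turn them back into plain text before display. The `desegment` helper below is hypothetical and not part of the original pipeline; it is a sketch under that assumption.

```python
import re


def desegment(caption: str) -> str:
    """Hypothetical helper: undo PhoBERT-style word segmentation.

    Assumes multi-syllable words are joined with "_" (e.g. "Việt_Nam"),
    as in PhoBERT's word-segmented input format.
    """
    # Replace word-joining underscores with spaces
    text = caption.replace("_", " ")
    # Collapse runs of whitespace and trim the ends
    return re.sub(r"\s+", " ", text).strip()


if __name__ == "__main__":
    print(desegment("Hồ_Chí_Minh là nhà cách_mạng Việt_Nam"))
```

If needed, it could be applied inside `generate_caption` right after decoding, e.g. `decoded_preds = [desegment(p) for p in decoded_preds]`.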

Define a function that captions randomly sampled images and plots each image next to its generated caption

def generate_caption(model, dataset, tokenizer, device, num_images=20, max_length=50,
                     output_file="generated_captions.png"):
    model.eval()
    
    # Pick num_images random samples from the dataset
    sampled_indices = random.sample(range(len(dataset)), num_images)
    sampled_images = [dataset[idx]['image'] for idx in sampled_indices]
    pixel_values_list = []
    
    for image in sampled_images:
        # Force 3 channels (grayscale or RGBA images would break the channel move below)
        image = image.convert("RGB").resize((224, 224))
        image = np.array(image, dtype=np.uint8)
        # HWC -> CHW float tensor with raw 0-255 values (no normalization)
        image = torch.tensor(np.moveaxis(image, -1, 0), dtype=torch.float32)
        pixel_values_list.append(image)

    pixel_values = torch.stack(pixel_values_list).to(device)
    
    # Beam-search decoding without gradient tracking
    with torch.no_grad():
        outputs = model.generate(pixel_values, num_beams=10, max_length=max_length, early_stopping=True, length_penalty=1.0)
    
    decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # One row per image: the image on the left, its caption on the right
    # (squeeze=False keeps axs 2-D even when num_images == 1)
    fig, axs = plt.subplots(num_images, 2, figsize=(15, 5 * num_images), squeeze=False)
    
    for i, (image, caption) in enumerate(zip(sampled_images, decoded_preds)):
        axs[i, 0].imshow(image)
        axs[i, 0].axis('off')
        axs[i, 1].text(0, 0.5, caption, wrap=True, fontsize=12)
        axs[i, 1].axis('off')
    
    plt.tight_layout()
    
    # Save the plot to a local file (on Kaggle, pass output_file="/kaggle/working/generated_captions.png")
    plt.savefig(output_file)
    plt.show()

    print(f"Plot saved as {output_file}")

Run the function and enjoy (this call captions 5 random images with a maximum generation length of 70 tokens)

generate_caption(model, dataset, phobert_tokenizer, device, num_images=5, max_length=70)
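The card metadata lists BLEU as the evaluation metric, but no evaluation code is included. Below is a minimal, dependency-free sketch of sentence-level BLEU (clipped n-gram precisions up to 4-grams, geometric mean, brevity penalty) for comparing one generated caption with one reference. It assumes whitespace tokenization and applies no smoothing, so any caption without a matching 4-gram scores 0; which dataset field (for example `text`) should serve as the reference is also an assumption. For reported numbers, a standard implementation such as sacreBLEU is the safer choice.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count all n-grams (as tuples) in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(candidate: str, reference: str, max_n: int = 4) -> float:
    """Unsmoothed sentence BLEU with uniform weights and a brevity penalty.

    Assumes whitespace tokenization; a sketch, not a replacement for sacreBLEU.
    """
    cand, ref = candidate.split(), reference.split()
    if not cand:
        return 0.0
    log_precisions = []
    for n in range(1, max_n + 1):
        cand_counts, ref_counts = ngrams(cand, n), ngrams(ref, n)
        # Clip each candidate n-gram count by its count in the reference
        overlap = sum(min(c, ref_counts[g]) for g, c in cand_counts.items())
        total = sum(cand_counts.values())
        if overlap == 0 or total == 0:
            return 0.0  # no smoothing: any zero precision gives BLEU = 0
        log_precisions.append(math.log(overlap / total))
    # Brevity penalty: penalize candidates shorter than the reference
    bp = 1.0 if len(cand) > len(ref) else math.exp(1 - len(ref) / len(cand))
    return bp * math.exp(sum(log_precisions) / max_n)


if __name__ == "__main__":
    ref = "Hồ Chí Minh là nhà cách mạng Việt Nam"
    print(sentence_bleu(ref, ref))  # 1.0
```

A corpus-level score would instead pool the clipped n-gram counts and lengths over all captions before taking the geometric mean, rather than averaging per-sentence scores.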