---
license: mit
pipeline_tag: image-feature-extraction
tags:
- pretrained
datasets:
- ylecun/mnist
---
# Model Card for Llava-mnist

Llava-mnist is a simple example of a vision-and-language model that uses the LLaVA architecture and is trained on the MNIST dataset.

You can use this model (a vision encoder made of just one linear layer) alongside [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct).

## Training Details

The model was trained on a *chat-style* version of the MNIST dataset, structured as follows:

prompt: "\<image\>What digit is this?"

output: "The digit is {label}."
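
For illustration only, here is a minimal sketch of how such prompt/output pairs could be built from the raw MNIST split. The templates are the ones shown above; the helper name `to_chat_example` and the use of `datasets.map` are assumptions, not the actual preprocessing script:

```python
from datasets import load_dataset


def to_chat_example(example):
    # Hypothetical helper: attach the chat-style prompt/output pair to an MNIST record.
    example["prompt"] = "<image>What digit is this?"
    example["output"] = f"The digit is {example['label']}."
    return example


mnist = load_dataset("ylecun/mnist", split="train")
chat_mnist = mnist.map(to_chat_example)
print(chat_mnist[0]["prompt"], "->", chat_mnist[0]["output"])
```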
The Llava-MNIST model transforms the digit image into an embedding vector that resides in the same space as the text token embeddings.
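
Because the encoder is a single linear layer, its core logic can be sketched in a few lines: it projects a flattened, normalized 28x28 image (784 values) into the token embedding width of Llama 3.1 8B (4096). The class below is an illustrative sketch with assumed names and dimensions, not the exact module shipped in this repository:

```python
import torch
import torch.nn as nn


class LlavaMnistEncoder(nn.Module):
    """Sketch of a one-linear-layer vision encoder (names and dimensions are assumptions)."""

    def __init__(self, image_dim: int = 28 * 28, hidden_dim: int = 4096):
        super().__init__()
        # Project a flattened MNIST image directly into the LLM's token embedding space.
        self.proj = nn.Linear(image_dim, hidden_dim)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # (batch, 784) -> (batch, 4096): each image becomes one "image token" embedding.
        return self.proj(pixel_values)


encoder = LlavaMnistEncoder()
print(encoder(torch.randn(1, 28 * 28)).shape)  # torch.Size([1, 4096])
```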
The loss function optimized during training is defined as:

$L(W) = -\log P_W(\text{The digit is \{label\}.} \mid \text{<image>What digit is this?})$
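
Concretely, this is ordinary next-token cross-entropy restricted to the answer tokens, computed on a sequence in which the `<image>` placeholder has been replaced by the image embedding. The sketch below shows one way to compute it; the helper name, tokenization details, and masking are assumptions rather than the exact training code:

```python
import torch
import torch.nn.functional as F


def llava_mnist_loss(model, tokenizer, vision_model, pixel_values, label):
    """Illustrative sketch: cross-entropy on the answer tokens only."""
    embed = model.get_input_embeddings()
    prefix_text, suffix_text = "<image>What digit is this?".split("<image>")
    prefix_ids = tokenizer(prefix_text, add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer(suffix_text, add_special_tokens=False)["input_ids"]
    answer_ids = tokenizer(f"The digit is {label}.", add_special_tokens=False)["input_ids"]

    to_emb = lambda ids: embed(torch.tensor(ids, dtype=torch.long, device=model.device))
    image_emb = vision_model(pixel_values).to(model.dtype).to(model.device)  # (1, hidden)

    # Sequence layout: <prompt prefix> <image embedding> <prompt suffix> <answer tokens>
    inputs_embeds = torch.cat(
        [to_emb(prefix_ids), image_emb, to_emb(suffix_ids), to_emb(answer_ids)], dim=0
    ).unsqueeze(0)
    logits = model(inputs_embeds=inputs_embeds).logits  # (1, seq_len, vocab)

    # Supervise only the positions that predict the answer tokens.
    n = len(answer_ids)
    answer_logits = logits[0, -n - 1 : -1, :]
    targets = torch.tensor(answer_ids, dtype=torch.long, device=logits.device)
    return F.cross_entropy(answer_logits.float(), targets)
```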
During training, the parameters of the Llama 3.1 model are kept frozen, and only the parameters of the vision encoder (Llava-MNIST) are optimized.
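
A minimal sketch of that setup is shown below; the optimizer choice and learning rate are assumptions, and `model` / `vision_model` are loaded the same way as in the "How to use" section:

```python
import torch
from transformers import AutoModel, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
vision_model = AutoModel.from_pretrained("speed/llava-mnist", trust_remote_code=True)

# Freeze every Llama 3.1 parameter; only the linear vision encoder receives gradients.
for param in model.parameters():
    param.requires_grad = False

# Hypothetical optimizer setup: pass only the vision encoder's parameters.
optimizer = torch.optim.AdamW(vision_model.parameters(), lr=1e-3)
```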
## How to use

You can input multi-modal data (vision and text) into the Llama 3.1 model by using the Llava-MNIST model as the vision encoder.
```python
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
from torchvision import transforms


def build_multi_modal_prompt(
    prompt: str,
    image: torch.Tensor,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    vision_model: AutoModel,
) -> torch.Tensor:
    """Replace the <image> placeholder with the image embedding from the vision encoder."""
    parts = prompt.split("<image>")
    prefix = tokenizer(parts[0])
    suffix = tokenizer(parts[1])
    prefix_embedding = model.get_input_embeddings()(
        torch.tensor(prefix["input_ids"], device=model.device)
    )
    suffix_embedding = model.get_input_embeddings()(
        torch.tensor(suffix["input_ids"], device=model.device)
    )
    image_embedding = vision_model(image).to(torch.bfloat16).to(model.device)
    multi_modal_embedding = torch.cat(
        [prefix_embedding, image_embedding, suffix_embedding], dim=0
    )
    return multi_modal_embedding


model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# The vision encoder: a single linear layer that maps a flattened MNIST image
# into Llama 3.1's token embedding space.
vision_model = AutoModel.from_pretrained("speed/llava-mnist", trust_remote_code=True)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Llama 3.1 chat template written out by hand; <image> marks where the image
# embedding is spliced into the sequence.
system_prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|>"
)
user_prompt = "<|start_header_id|>user<|end_header_id|>"
question = "<image>What digit is this?"
assistant_prompt = "<|start_header_id|>assistant<|end_header_id|>"

prompt = system_prompt + user_prompt + question + assistant_prompt

ds = load_dataset("ylecun/mnist", split="test")


def transform_image(examples):
    # Normalize with the usual MNIST statistics and flatten each image to a
    # 784-dimensional vector, the input expected by the linear vision encoder.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
            transforms.Lambda(lambda x: torch.flatten(x)),
        ]
    )
    examples["pixel_values"] = [transform(image) for image in examples["image"]]
    return examples


ds.set_transform(transform=transform_image)

model.eval()
vision_model.eval()

example = ds[0]

input_embeds = build_multi_modal_prompt(
    prompt, example["pixel_values"].unsqueeze(0), tokenizer, model, vision_model
).unsqueeze(0)
response = model.generate(
    inputs_embeds=input_embeds,
    max_new_tokens=20,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
response = response[0]
print("Label:", example["label"])  # Label: 7
answer = tokenizer.decode(response, skip_special_tokens=True)
print("Answer:", answer)  # Answer: The digit is 7.
```
## References

- Liu et al., LLaVA: Large Language and Vision Assistant, https://llava-vl.github.io/