kiranr's picture
Create README.md
d181c9e verified

palmyra-vision

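Usage

The example below loads the processor and the model, downloads a sample image, asks a question about it, and prints the generated answer.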

from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests
import torch

# End-to-end example: load the processor and model, download one image,
# ask a question about it, and print the generated answer.

# Single source of truth for the repo id so the processor and the model can
# never be loaded from two different checkpoints by accident.
MODEL_ID = "Writer/palmyra-vision-dummy-weights"
IMAGE_URL = "https://picsum.photos/seed/picsum/200/300"
PROMPT = "what is this image about?"
# Without a timeout, requests waits forever on a stalled server.
REQUEST_TIMEOUT_SECONDS = 30

# NOTE: trust_remote_code=True lets transformers execute Python code shipped
# in the model repository; only use it with repositories you trust.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    use_fast=False,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)

# Download the image. The context manager releases the connection, and
# raise_for_status() turns an HTTP error into a clear exception instead of
# letting PIL fail later while trying to decode an error page.
with requests.get(
    IMAGE_URL, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    # Image.open is lazy: force the pixel data to be read now, while the
    # underlying stream is still open.
    image.load()

inputs = processor.process(
    images=[image],
    text=PROMPT,
)

# Move every tensor to the model's device and add a leading dimension of
# size 1 (presumably the batch axis expected by generate_from_batch).
inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

# generate_from_batch is not a standard transformers method; it is presumably
# provided by the repository's remote code. The tokenizer is passed alongside
# stop_strings so the stop string can be matched against decoded text.
output = model.generate_from_batch(
    inputs,
    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
    tokenizer=processor.tokenizer,
)

# Keep only the tokens after the prompt: skip the first
# inputs["input_ids"].size(1) positions of the first (only) sequence.
generated_tokens = output[0, inputs["input_ids"].size(1) :]
generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

print(generated_text)