Batch inference

#1
by SeeingFarther - opened

Could you provide information on available APIs or methods for performing batch inference, or is it not supported?

Efficient-Large-Model org

Sure, I've just updated the docs and the usage example.

Efficient-Large-Model org
from transformers import AutoProcessor, AutoModel

model_path = "Efficient-Large-Model/NVILA-Lite-2B-hf-preview"
# Or point to a local checkout instead:
# model_path = "./NVILA-Lite-2B-hf-preview"
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
# Important: call model.eval() before generating. In training mode the model pads
# sequences on the right, but batched generation needs left padding so every
# prompt ends at the same position.
model.eval()

# Two independent single-turn conversations; each becomes one item in the batch.
gpt_conv1 = [{
    "role": "user",
    "content": [
        {"type": "image", "path": "demo_images/demo_img_1.png"},
        {"type": "text", "text": "Describe this image."}
    ]
}]
gpt_conv2 = [{
    "role": "user",
    "content": [
        {"type": "image", "path": "demo_images/demo_img_2.png"},
        {"type": "text", "text": "Describe this image for me. Provide a detailed description of the image."}
    ]
}]

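# Render each conversation to a prompt string, then let the processor turn
# the list of prompts into a single batch of model inputs.
messages = [gpt_conv1, gpt_conv2]
texts = [
    processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
    for msg in messages
]
inputs = processor(texts)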

output_ids = model.generate(
    input_ids=inputs.input_ids,
    media=inputs.media,
    media_config=inputs.media_config,
    generation_config=model.generation_config,
    max_new_tokens=256,
)
# One decoded string per conversation, in the same order as `messages`.
output_texts = processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
for i, text in enumerate(output_texts):
    if i:
        print("---" * 40)
    print(text)
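To apply the same pattern to more images than fit in a single batch, one option is to split the inputs into fixed-size chunks and run each chunk through the processor and `model.generate` exactly as above. The sketch below assumes this approach; `chunked` and `make_conversation` are hypothetical helper names, not part of the NVILA API, and the batch size of 4 in the commented usage is an arbitrary assumption to tune to your GPU memory.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunked(items: Sequence[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of `items` with at most `size` elements each."""
    if size < 1:
        raise ValueError("size must be >= 1")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


def make_conversation(image_path: str, prompt: str) -> list:
    """Build one single-turn conversation in the message format used above."""
    return [{
        "role": "user",
        "content": [
            {"type": "image", "path": image_path},
            {"type": "text", "text": prompt},
        ],
    }]


# Hypothetical usage with the `processor` and `model` loaded earlier
# (`image_paths` is an assumed list of image file paths):
# for paths in chunked(image_paths, size=4):
#     messages = [make_conversation(p, "Describe this image.") for p in paths]
#     texts = [processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
#              for m in messages]
#     inputs = processor(texts)
#     output_ids = model.generate(input_ids=inputs.input_ids, media=inputs.media,
#                                 media_config=inputs.media_config,
#                                 generation_config=model.generation_config,
#                                 max_new_tokens=256)
#     decoded = processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
#     for path, text in zip(paths, decoded):  # outputs keep the input order
#         print(path, text)

print([len(batch) for batch in chunked(list(range(10)), 4)])  # [4, 4, 2]
```

Chunking keeps peak memory bounded by the batch size, and because each chunk preserves input order, zipping the decoded texts with the chunk's paths pairs every description with its image.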
Ligeng-Zhu changed discussion status to closed