reach-vb HF Staff committed on
Commit 841e9be · verified · 1 Parent(s): 98098d3

Update README.md

  1. README.md +46 -0
README.md CHANGED
@@ -51,6 +51,52 @@ python predict.py --model-path /path/to/checkpoint-dir \
  --prompt "Describe the image."
  ```
  
+ ### Run inference with Transformers (Remote Code)
+ To run inference with Transformers, pass `trust_remote_code=True` when loading the tokenizer and the model, as in the following snippet:
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ MID = "apple/FastVLM-0.5B"
+ IMAGE_TOKEN_INDEX = -200  # what the model code looks for
+ # Load
+ tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     MID,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+ # Build chat -> render to string (not tokens) so we can place <image> exactly
+ messages = [
+     {"role": "user", "content": "<image>\nDescribe this image in detail."}
+ ]
+ rendered = tok.apply_chat_template(
+     messages, add_generation_prompt=True, tokenize=False
+ )
+ pre, post = rendered.split("<image>", 1)
+ # Tokenize the text *around* the image token (no extra specials!)
+ pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
+ post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
+ # Splice in the IMAGE token id (-200) at the placeholder position
+ img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
+ input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
+ attention_mask = torch.ones_like(input_ids, device=model.device)
+ # Preprocess image via the model's own processor
+ img = Image.open("test-2.jpg").convert("RGB")
+ px = model.get_vision_tower().image_processor(images=img, return_tensors="pt")["pixel_values"]
+ px = px.to(model.device, dtype=model.dtype)
+ # Generate
+ with torch.no_grad():
+     out = model.generate(
+         inputs=input_ids,
+         attention_mask=attention_mask,
+         images=px,
+         max_new_tokens=128,
+     )
+ print(tok.decode(out[0], skip_special_tokens=True))
+ ```
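
The splice step in the snippet above can be tried in isolation, without downloading the model. The following is a minimal sketch and not part of the committed README: `splice_image_token` and `toy_encode` are hypothetical names introduced for illustration, and the character-level encoder merely stands in for the real tokenizer. Only the `<image>` placeholder and the `-200` sentinel come from the snippet.

```python
from typing import Callable, List

IMAGE_TOKEN_INDEX = -200  # sentinel id taken from the README snippet


def splice_image_token(
    rendered: str,
    encode: Callable[[str], List[int]],
    placeholder: str = "<image>",
    image_token_index: int = IMAGE_TOKEN_INDEX,
) -> List[int]:
    """Return ids for `rendered` with one sentinel id in place of `placeholder`."""
    if rendered.count(placeholder) != 1:
        raise ValueError(f"expected exactly one {placeholder!r} in the prompt")
    pre, post = rendered.split(placeholder, 1)
    # Encode each side separately so the placeholder text is never tokenized.
    return encode(pre) + [image_token_index] + encode(post)


def toy_encode(text: str) -> List[int]:
    """Hypothetical stand-in for a tokenizer: one id per character."""
    return [ord(ch) for ch in text]


ids = splice_image_token("Hi <image>\nDescribe.", toy_encode)
print(ids.index(IMAGE_TOKEN_INDEX))  # position of the image sentinel -> 3
```

With the real tokenizer, `encode` would presumably be something like `lambda s: tok(s, add_special_tokens=False).input_ids`, and the resulting list would be wrapped in a batch dimension before being passed to `generate`. Rejecting prompts with zero or several placeholders is a design choice of this sketch; the README snippet itself splits on the first occurrence only.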
  
  ## Citation
  If you found this model useful, please cite the following paper: